Commit b7f1b050 authored by Neel Kant

Lint whole repo

parent c99fa80c
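Note: the changes below are mechanical style fixes (operator spacing, "# "-prefixed comments, bare except: replaced with except BaseException:, blank-line and long-line wrapping). A repo-wide pass like this can be reproduced with a formatter such as autopep8; the sketch below is only an illustration under that assumption, since the commit does not say which tool or settings were used.

# Hypothetical reproduction of a repo-wide lint pass (autopep8 and the chosen
# options are assumptions; the commit does not name a tool).
import pathlib

import autopep8

for path in pathlib.Path('.').rglob('*.py'):
    source = path.read_text()
    # aggressive=1 also enables fixes such as E722 (bare except ->
    # except BaseException), in addition to whitespace and wrapping fixes.
    fixed = autopep8.fix_code(source, options={'aggressive': 1})
    if fixed != source:
        path.write_text(fixed)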
@@ -357,7 +357,6 @@ def _add_gpt2_args(parser):
     return parser
 def add_data_args_(parser):
     """Train/valid/test data arguments."""
@@ -367,6 +366,4 @@ def add_data_args_(parser):
                        choices=['raw', 'lazy', 'tfrecords', 'numpy', 'binary'],
                        help='Which data loader to use. Default varies by model.')
     return parser
@@ -67,7 +67,7 @@ def get_checkpoint_name(checkpoints_path, iteration,
     directory = 'iter_{:07d}'.format(iteration)
     return os.path.join(checkpoints_path, directory,
                         'mp_rank_{:02d}'.format(
-                            mpu.get_model_parallel_rank() if mp_rank is None \
+                            mpu.get_model_parallel_rank() if mp_rank is None
                             else mp_rank),
                         'model_optim_rng.pt')
@@ -179,7 +179,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
                 'megatron.fp16.loss_scaler']
         state_dict = torch.load(checkpoint_name, map_location='cpu')
         sys.modules.pop('fp16.loss_scaler', None)
-    except:
+    except BaseException:
        print_rank_0('could not load the checkpoint')
        sys.exit()
@@ -190,7 +190,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
     try:
         iteration = state_dict['iteration']
     except KeyError:
-        try: # Backward compatible with older checkpoints
+        try:  # Backward compatible with older checkpoints
             iteration = state_dict['total_iters']
         except KeyError:
             print_rank_0('A metadata file exists but unable to load '
...
 from . import indexed_dataset
@@ -47,6 +47,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     # Print stats about the splits.
     print_rank_0(' > dataset split:')
     def print_split_stats(name, index):
         print_rank_0(' {}:'.format(name))
         print_rank_0(' document indices in [{}, {}) total of {} '
@@ -113,7 +114,6 @@ class BertDataset(Dataset):
         # Dataset.
         self.indexed_dataset = indexed_dataset
         # Build the samples mapping.
         self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                     data_prefix,
@@ -133,11 +133,9 @@ class BertDataset(Dataset):
         self.mask_id = tokenizer.mask
         self.pad_id = tokenizer.pad
     def __len__(self):
         return self.samples_mapping.shape[0]
     def __getitem__(self, idx):
         start_index, end_index, seq_length = self.samples_mapping[idx]
@@ -148,7 +146,7 @@ class BertDataset(Dataset):
         # python randint is inclusive whereas the numpy one is exclusive.
         np_rng = np.random.RandomState(seed=(self.seed + idx))
         return build_training_sample(sample, seq_length,
-                                     self.max_seq_length, # needed for padding
+                                     self.max_seq_length,  # needed for padding
                                      self.vocab_id_list,
                                      self.vocab_id_to_token_dict,
                                      self.cls_id, self.sep_id,
@@ -192,7 +190,7 @@ def get_train_valid_test_split_(splits_string, size):
     splits = splits[:3]
     splits_sum = sum(splits)
     assert splits_sum > 0.0
-    splits = [split/splits_sum for split in splits]
+    splits = [split / splits_sum for split in splits]
     splits_index = [0]
     for index, split in enumerate(splits):
         splits_index.append(splits_index[index] +
@@ -254,7 +252,7 @@ def get_samples_mapping_(indexed_dataset,
                                                indexed_dataset.sizes,
                                                num_epochs,
                                                max_num_samples,
-                                               max_seq_length-3, # account for added tokens
+                                               max_seq_length - 3,  # account for added tokens
                                                short_seq_prob,
                                                seed,
                                                verbose)
...
@@ -42,6 +42,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     # Print stats about the splits.
     print_rank_0(' > dataset split:')
     def print_split_stats(name, index):
         print_rank_0(' {}:'.format(name))
         print_rank_0(' document indices in [{}, {}) total of {} '
@@ -54,7 +55,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     def build_dataset(index, name):
         dataset = None
         if splits[index + 1] > splits[index]:
-            documents = np.arange(start=splits[index], stop=splits[index+1],
+            documents = np.arange(start=splits[index], stop=splits[index + 1],
                                   step=1, dtype=np.int32)
             dataset = GPT2Dataset(name, data_prefix,
                                   documents, indexed_dataset,
@@ -102,21 +103,19 @@ class GPT2Dataset(torch.utils.data.Dataset):
             self.name, data_prefix, documents, self.indexed_dataset.sizes,
             num_samples, seq_length, seed)
     def __len__(self):
         # -1 is due to data structure used to retieve the index:
         # sample i --> [sample_idx[i], sample_idx[i+1])
         return self.sample_idx.shape[0] - 1
     def __getitem__(self, idx):
         # Get the shuffled index.
         idx = self.shuffle_idx[idx]
         # Start and end documents and offsets.
         doc_index_f = self.sample_idx[idx][0]
-        doc_index_l = self.sample_idx[idx+1][0]
+        doc_index_l = self.sample_idx[idx + 1][0]
         offset_f = self.sample_idx[idx][1]
-        offset_l = self.sample_idx[idx+1][1]
+        offset_l = self.sample_idx[idx + 1][1]
         # If we are within the same document, just extract the chunk.
         if doc_index_f == doc_index_l:
             sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
@@ -127,18 +126,17 @@ class GPT2Dataset(torch.utils.data.Dataset):
             sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                                     offset=offset_f)]
             # Loop over all in between documents and add the entire document.
-            for i in range(doc_index_f+1, doc_index_l):
+            for i in range(doc_index_f + 1, doc_index_l):
                 sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
             # And finally add the relevant portion of last document.
             sample_list.append(self.indexed_dataset.get(
                 self.doc_idx[doc_index_l],
-                length=offset_l+1))
+                length=offset_l + 1))
             sample = np.concatenate(sample_list)
         return {'text': np.array(sample, dtype=np.int64)}
 def _build_index_mappings(name, data_prefix, documents, sizes,
                           num_samples, seq_length, seed):
     """Build doc-idx, sample-idx, and shuffle-idx.
@@ -185,7 +183,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
         assert sizes.dtype == np.int32
         sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                               num_epochs, tokens_per_epoch)
-        #sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
+        # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
         #                               num_epochs, tokens_per_epoch)
         np.save(sample_idx_filename, sample_idx, allow_pickle=True)
         print_rank_0(' > elasped time to build and save sample-idx mapping '
@@ -194,7 +192,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
         start_time = time.time()
         # -1 is due to data structure used to retieve the index:
         # sample i --> [sample_idx[i], sample_idx[i+1])
-        shuffle_idx = _build_shuffle_idx(sample_idx.shape[0]-1, np_rng)
+        shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
         np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
         print_rank_0(' > elasped time to build and save shuffle-idx mapping'
                      ' (seconds): {:4f}'.format(time.time() - start_time))
...
@@ -20,6 +20,7 @@ import numpy as np
 import torch
 from megatron import print_rank_0
 def __best_fitting_dtype(vocab_size=None):
     if vocab_size is not None and vocab_size < 65500:
         return np.uint16
@@ -109,13 +110,15 @@ def index_file_path(prefix_path):
 def data_file_path(prefix_path):
     return prefix_path + '.bin'
 def create_doc_idx(sizes):
     doc_idx = [0]
     for i, s in enumerate(sizes):
         if s == 0:
-            doc_idx.append(i+1)
+            doc_idx.append(i + 1)
     return doc_idx
 class IndexedDataset(torch.utils.data.Dataset):
     """Loader for IndexedDataset"""
     _HDR_MAGIC = b'TNTIDX\x00\x00'
@@ -155,7 +158,7 @@ class IndexedDataset(torch.utils.data.Dataset):
         if self.data_file:
             self.data_file.close()
-    #@lru_cache(maxsize=8)
+    # @lru_cache(maxsize=8)
     def __getitem__(self, idx):
         if not self.data_file:
             self.read_data(self.path)
@@ -235,7 +238,7 @@ class IndexedCachedDataset(IndexedDataset):
             self.data_file.close()
             self.data_file = None
-    #@lru_cache(maxsize=8)
+    # @lru_cache(maxsize=8)
     def __getitem__(self, idx):
         if isinstance(idx, int):
             i = idx
@@ -399,13 +402,18 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
             self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
             self._bin_buffer = memoryview(self._bin_buffer_mmap)
             print_rank_0(" reading sizes...")
-            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
+            self._sizes = np.frombuffer(
+                self._bin_buffer,
+                dtype=np.int32,
+                count=self._len,
+                offset=offset)
             print_rank_0(" reading pointers...")
             self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                            offset=offset + self._sizes.nbytes)
             print_rank_0(" reading document index...")
             self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                           offset=offset + self._sizes.nbytes + self._pointers.nbytes)
         def __del__(self):
             self._bin_buffer_mmap._mmap.close()
             del self._bin_buffer_mmap
@@ -464,7 +472,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     def __len__(self):
         return len(self._index)
-    #@lru_cache(maxsize=8)
+    # @lru_cache(maxsize=8)
     def __getitem__(self, idx):
         if isinstance(idx, int):
             ptr, size = self._index[idx]
...
@@ -81,6 +81,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     sampler level. This allows wrapping of arbitrary data samplers
     (sequential, random, WeightedRandomSampler, etc.) with this batch
     sampler."""
     def __init__(self, sampler, batch_size, drop_last, rank=-1,
                  world_size=2, wrap_last=False):
         super(DistributedBatchSampler, self).__init__(sampler, batch_size,
@@ -120,7 +121,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     def data_iterator(self, _iter, wrap_around=False):
         """iterates through data and handles wrap around"""
         for i, idx in enumerate(_iter):
-            if i < self.wrap_around%self.batch_size:
+            if i < self.wrap_around % self.batch_size:
                 continue
             if wrap_around:
                 self.wrap_around += 1
@@ -129,6 +130,6 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
     def _batch(self, batch):
         """extracts samples only pertaining to this worker's batch"""
-        start = self.rank*self.batch_size//self.world_size
-        end = (self.rank+1)*self.batch_size//self.world_size
+        start = self.rank * self.batch_size // self.world_size
+        end = (self.rank + 1) * self.batch_size // self.world_size
         return batch[start:end]
@@ -2,6 +2,8 @@
 # put some code used during development and manual testing of
 # indexed_dataset.
+from megatron.data import indexed_dataset
+from megatron.tokenizer import build_tokenizer
 import argparse
 import os
 import sys
@@ -11,8 +13,6 @@ import torch
 script_dir = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(os.path.join(script_dir, "../../../"))
-from megatron.tokenizer import build_tokenizer
-from megatron.data import indexed_dataset
 def test_indexed_dataset(args):
     ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
@@ -23,12 +23,12 @@ def test_indexed_dataset(args):
     if ds.supports_prefetch:
         # just prefetch the whole thing in test (so assume it is small)
         ds.prefetch(range(len(ds)))
-    if args.count > len(ds.doc_idx)-1:
-        args.count = len(ds.doc_idx)-1
+    if args.count > len(ds.doc_idx) - 1:
+        args.count = len(ds.doc_idx) - 1
     for i in range(args.count):
         start = ds.doc_idx[i]
-        end = ds.doc_idx[i+1]
+        end = ds.doc_idx[i + 1]
         ids = ds[start:end]
         print(f"Document {i}:")
         print("--------------")
@@ -39,6 +39,7 @@ def test_indexed_dataset(args):
             print(text)
             print("---")
 def test_indexed_dataset_get(args):
     ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
     tokenizer = build_tokenizer(args)
@@ -46,19 +47,19 @@ def test_indexed_dataset_get(args):
     print(f"size: {size}")
     full = ds.get(0)
     print(full)
-    #print(tokenizer.detokenize(full.data.tolist()))
+    # print(tokenizer.detokenize(full.data.tolist()))
     print("---")
-    end = ds.get(0, offset=size-10)
+    end = ds.get(0, offset=size - 10)
     print(end)
-    #print(tokenizer.detokenize(end.data.tolist()))
+    # print(tokenizer.detokenize(end.data.tolist()))
     start = ds.get(0, length=10)
     print(start)
-    #print(tokenizer.detokenize(start.data.tolist()))
+    # print(tokenizer.detokenize(start.data.tolist()))
     part = ds.get(0, offset=2, length=8)
     print(part)
-    #print(tokenizer.detokenize(part.data.tolist()))
+    # print(tokenizer.detokenize(part.data.tolist()))
 # def test_albert_dataset(args):
 # # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
@@ -77,6 +78,7 @@ def test_indexed_dataset_get(args):
 # if i >= args.count-1:
 # exit()
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--data', type=str, help='prefix to data files')
@@ -118,5 +120,6 @@ def main():
     # test_albert_dataset(args)
     test_indexed_dataset_get(args)
 if __name__ == "__main__":
     main()
@@ -28,21 +28,24 @@ TRAIN_DATA = 0
 VAL_DATA = 1
 TEST_DATA = 2
 def should_split(split):
     """
     given split proportions checks if should split
     Examples:
         >>> should_split([10,0,0])
         False
         >>> should_split([1,.1,.2])
         True
     """
-    return max(split)/sum(split) != 1.
+    return max(split) / sum(split) != 1.
 def get_ext(path):
     """gets path extension"""
     return os.path.splitext(path)[1]
 def get_dataset(path, **kwargs):
     """gets dataset object based on keyword args and file at `path`"""
     if supported_corpus(path):
@@ -53,17 +56,19 @@ def get_dataset(path, **kwargs):
     elif ext in ['.csv', '.tsv']:
         text = csv_dataset(path, **kwargs)
     else:
-        raise NotImplementedError('data file type %s is not supported'%(ext))
+        raise NotImplementedError('data file type %s is not supported' % (ext))
     return text
 def supported_corpus(corpus_name):
     """checks if corpus name is defined in `corpora.py`"""
     return corpus_name in corpora.NAMED_CORPORA
 def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.],
                  delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None,
                  tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None,
                  model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None,
                  parallel_group=None, **kwargs):
     """function to create datasets+tokenizers for common options"""
     if isinstance(process_fn, str):
@@ -71,6 +76,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
     if non_binary_cols is not None:
         # multilabel dataset support (only for csvs)
         label_key = non_binary_cols
     def get_dataset_from_path(path_):
         if lazy:
             # get lazily loaded dataset
@@ -82,7 +88,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
             if torch.distributed.get_rank() == 0 and not exists_lazy(path_, data_type='data'):
                 # create cached version of dataset for lazy loading if it doesn't exist
                 text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
                                    delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose)
                 make_lazy(path_, text.X, data_type='data')
             # This should be a barrier but nccl barrier assumes
             # device_index=rank which is not the case for model
@@ -96,7 +102,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
         else:
             # get dataset
             text = get_dataset(path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
                                delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, preprocess_fn=process_fn)
         return text
     # get one or multiple datasets and concatenate
     if isinstance(path, str):
@@ -108,8 +114,8 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
         ds = ConcatDataset(datasets)
     # make tokenizer for dataset
     if tokenizer is None:
         tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type,
                                    pad_token, character_converage, **kwargs)
     ds_type = ''
     if 'ds_type' in kwargs:
@@ -121,7 +127,8 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
     if 'bert' in ds_type.lower():
         presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
         dstype = bert_sentencepair_dataset
-        ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
+        ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
+              if d is not None else None for d in ds]
     elif ds_type.lower() == 'gpt2':
         ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
     else:
...
@@ -21,6 +21,7 @@ import torch
 from megatron import data_utils
 from megatron import mpu
 class DataConfig:
     def __init__(self, defaults={}):
@@ -48,7 +49,8 @@ def make_data_loader(dataset, batch_size, args):
     shuffle = args.shuffle
     if shuffle:
-        sampler = data_utils.samplers.RandomSampler(dataset, replacement=True, num_samples=batch_size*args.train_iters)
+        sampler = data_utils.samplers.RandomSampler(
+            dataset, replacement=True, num_samples=batch_size * args.train_iters)
     else:
         sampler = torch.utils.data.SequentialSampler(dataset)
     world_size = torch.distributed.get_world_size(
@@ -204,6 +206,7 @@ def make_loaders(args):
     return (train, valid, test), tokenizer
 def get_split(args):
     """
     Get dataset splits from comma separated string list
@@ -217,7 +220,7 @@ def get_split(args):
         splits = [float(args.split)]
     split_total = sum(splits)
     if split_total < 1.:
-        splits.append(1-split_total)
+        splits.append(1 - split_total)
     while len(splits) < 3:
         splits.append(0.)
     splits = splits[:3]
@@ -226,10 +229,10 @@ def get_split(args):
     if args.test_data is not None:
         splits[2] = 0.
     final_sum = sum(splits)
-    return [s/final_sum for s in splits]
+    return [s / final_sum for s in splits]
 def configure_data():
     """add cmdline flags for configuring datasets"""
     # These are options that are used by data_utils, but are either
     # deprecated or not meant to be exposed to the command line user.
...
@@ -16,43 +16,46 @@
 from .datasets import json_dataset, csv_dataset
 import os
 class wikipedia(json_dataset):
     """
     dataset for wikipedia with arguments configured for convenience
     command line usage: `--train-data wikipedia`
     """
     PATH = 'data/wikipedia/wikidump_lines.json'
     assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"
+
     def __init__(self, **kwargs):
         assert os.path.exists(wikipedia.PATH), \
             wikipedia.assert_str
         if not kwargs:
             kwargs = {}
         kwargs['text_key'] = 'text'
         kwargs['loose_json'] = True
         super(wikipedia, self).__init__(wikipedia.PATH, **kwargs)
 class webtext(json_dataset):
     """
     dataset for webtext with arguments configured for convenience
     command line usage: `--train-data webtext`
     """
     PATH = 'data/webtext/data.json'
     assert_str = "make sure to set PATH for webtext data_utils/corpora.py"
+
     def __init__(self, **kwargs):
         assert os.path.exists(webtext.PATH), \
             webtext.assert_str
         if not kwargs:
             kwargs = {}
         kwargs['text_key'] = 'text'
         kwargs['loose_json'] = True
         super(webtext, self).__init__(webtext.PATH, **kwargs)
 NAMED_CORPORA = {
     'wikipedia': wikipedia,
     'webtext': webtext,
 }
@@ -34,6 +34,7 @@ from nltk import tokenize
 from .lazy_loader import lazy_array_loader, exists_lazy, make_lazy
 from .tokenization import Tokenization
 class ConcatDataset(data.Dataset):
     """
     Dataset to concatenate multiple datasets.
@@ -57,7 +58,8 @@ class ConcatDataset(data.Dataset):
         super(ConcatDataset, self).__init__()
         assert len(datasets) > 0, 'datasets should not be an empty iterable'
         self.datasets = list(datasets)
-        self.is_lazy = sum([isinstance(ds, lazy_array_loader) for ds in self.datasets]) == len(self.datasets)
+        self.is_lazy = sum([isinstance(ds, lazy_array_loader)
+                            for ds in self.datasets]) == len(self.datasets)
         self.cumulative_sizes = self.cumsum(self.datasets)
         self._X = None
         self._Y = None
@@ -90,7 +92,8 @@ class ConcatDataset(data.Dataset):
                 self._lens.extend(data.lens)
             else:
                 for data in self.datasets:
-                    self._lens.extend([len(d['text']) if isinstance(d, dict) else len(d) for d in data])
+                    self._lens.extend([len(d['text']) if isinstance(
+                        d, dict) else len(d) for d in data])
         return self._lens
     @property
@@ -116,6 +119,7 @@ class ConcatDataset(data.Dataset):
                       "cumulative_sizes", DeprecationWarning, stacklevel=2)
         return self.cumulative_sizes
 class SplitDataset(data.Dataset):
     """
     Dataset wrapper to access a subset of another dataset.
@@ -126,6 +130,7 @@ class SplitDataset(data.Dataset):
         ds (Dataset or array-like): List of datasets to be subindexed
         split_inds (1D array-like): List of indices part of subset
     """
     def __init__(self, ds, split_inds, **kwargs):
         self.split_inds = list(split_inds)
         self.wrapped_data = ds
@@ -163,7 +168,8 @@ class SplitDataset(data.Dataset):
         for idx in self.split_inds:
             yield self.wrapped_data[idx]
-def split_ds(ds, split=[.8,.2,.0], shuffle=True):
+def split_ds(ds, split=[.8, .2, .0], shuffle=True):
     """
     Split a dataset into subsets given proportions of how
     much to allocate per split. If a split is 0% returns None for that split.
@@ -184,18 +190,19 @@ def split_ds(ds, split=[.8,.2,.0], shuffle=True):
         np.random.shuffle(inds)
     start_idx = 0
     residual_idx = 0
-    rtn_ds = [None]*len(split)
+    rtn_ds = [None] * len(split)
     for i, f in enumerate(split):
         if f != 0:
-            proportion = ds_len*split[i]
+            proportion = ds_len * split[i]
             residual_idx += proportion % 1
             split_ = int(int(proportion) + residual_idx)
-            split_inds = inds[start_idx:start_idx+max(split_, 1)]
+            split_inds = inds[start_idx:start_idx + max(split_, 1)]
             rtn_ds[i] = SplitDataset(ds, split_inds)
             start_idx += split_
             residual_idx %= 1
     return rtn_ds
 class csv_dataset(data.Dataset):
     """
     Class for loading datasets from csv files.
@@ -214,9 +221,10 @@ class csv_dataset(data.Dataset):
         X (list): all strings from the csv file
         Y (np.ndarray): labels to train with
     """
     def __init__(self, path, tokenizer=None, preprocess_fn=None, delim=',',
                  binarize_sent=False, drop_unlabeled=False, text_key='sentence', label_key='label',
                  **kwargs):
         self.is_lazy = False
         self.preprocess_fn = preprocess_fn
         self.SetTokenizer(tokenizer)
@@ -229,7 +237,6 @@ class csv_dataset(data.Dataset):
         if '.tsv' in self.path:
             self.delim = '\t'
         self.X = []
         self.Y = []
         try:
@@ -239,7 +246,7 @@ class csv_dataset(data.Dataset):
             else:
                 cols += [label_key]
             data = pd.read_csv(self.path, sep=self.delim, usecols=cols, encoding='latin-1')
-        except:
+        except BaseException:
             data = pd.read_csv(self.path, sep=self.delim, usecols=[text_key], encoding='latin-1')
         data = data.dropna(axis=0)
@@ -248,7 +255,7 @@ class csv_dataset(data.Dataset):
         try:
             self.Y = data[label_key].values
         except Exception as e:
-            self.Y = np.ones(len(self.X))*-1
+            self.Y = np.ones(len(self.X)) * -1
         if binarize_sent:
             self.Y = binarize_labels(self.Y, hard=binarize_sent)
@@ -295,23 +302,25 @@ class csv_dataset(data.Dataset):
         write the metrics, text, and labels to a csv file
         """
         if path is None:
-            path = self.path+'.results'
+            path = self.path + '.results'
         print('generating csv at ' + path)
         with open(path, 'w') as csvfile:
             c = csv.writer(csvfile, delimiter=self.delim)
             if writer_gen is not None:
-                #if first item of generator is a header of what the metrics mean then write header to csv file
+                # if first item of generator is a header of what the metrics mean then
+                # write header to csv file
                 if not skip_header:
-                    header = (self.label_key,)+tuple(next(writer_gen))+(self.text_key,)
+                    header = (self.label_key,) + tuple(next(writer_gen)) + (self.text_key,)
                     c.writerow(header)
                 for i, row in enumerate(writer_gen):
-                    row = (self.Y[i],)+tuple(row)+(self.X[i],)
+                    row = (self.Y[i],) + tuple(row) + (self.X[i],)
                     c.writerow(row)
             else:
                 c.writerow([self.label_key, self.text_key])
                 for row in zip(self.Y, self.X):
                     c.writerow(row)
 class json_dataset(data.Dataset):
     """
     Class for loading datasets from a json dump.
@@ -327,8 +336,9 @@ class json_dataset(data.Dataset):
         all_strs (list): list of all strings from the dataset
         all_labels (list): list of all labels from the dataset (if they have it)
     """
     def __init__(self, path, tokenizer=None, preprocess_fn=None, binarize_sent=False,
                  text_key='sentence', label_key='label', loose_json=False, **kwargs):
         self.is_lazy = False
         self.preprocess_fn = preprocess_fn
         self.path = path
@@ -389,24 +399,25 @@ class json_dataset(data.Dataset):
         write the metrics, text, and labels to a json file
         """
         if path is None:
-            path = self.path+'.results'
+            path = self.path + '.results'
         jsons = []
         if writer_gen is not None:
-            #if first item of generator is a header of what the metrics mean then write header to csv file
+            # if first item of generator is a header of what the metrics mean then
+            # write header to csv file
             def gen_helper():
                 keys = {}
                 keys[0] = self.label_key
                 if not skip_header:
                     for idx, k in enumerate(tuple(next(writer_gen))):
-                        keys[idx+1] = k
+                        keys[idx + 1] = k
                 for i, row in enumerate(writer_gen):
                     if i == 0 and skip_header:
                         for idx, _ in enumerate(row):
-                            keys[idx+1] = 'metric_%d'%(idx,)
+                            keys[idx + 1] = 'metric_%d' % (idx,)
                     j = {}
-                    for idx, v in enumerate((self.Y[i],)+tuple(row)):
+                    for idx, v in enumerate((self.Y[i],) + tuple(row)):
                         k = keys[idx]
                         j[k] = v
                     yield j
@@ -453,6 +464,7 @@ class json_dataset(data.Dataset):
                 j[self.label_key] = -1
             yield j
 class GPT2Dataset(data.Dataset):
     def __init__(self, ds,
@@ -503,7 +515,7 @@ class GPT2Dataset(data.Dataset):
     def __getitem__(self, idx):
         # init rng
         rng = random.Random(idx)
-        rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
+        rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
         # get possibly weighted random index from dataset
         data_idx = self.get_weighted_samples(rng)
@@ -538,10 +550,10 @@ class GPT2Dataset(data.Dataset):
             else:
                 data_idx = (data_idx + 1) % self.ds_len
             tokens += self.getidx(data_idx)
-        tokens = tokens[:(self.max_seq_len+1)]
+        tokens = tokens[:(self.max_seq_len + 1)]
         tokens = self.pad_seq(tokens)
-        return {'text': np.array(tokens),}
+        return {'text': np.array(tokens), }
     def getidx(self, data_idx):
         data = self.ds[data_idx]
@@ -556,7 +568,7 @@ class GPT2Dataset(data.Dataset):
     def pad_seq(self, seq):
         total_tokens = self.max_seq_len + 1
         num_pad_tokens = max(0, total_tokens - len(seq))
-        seq += [self.tokenizer.get_command('pad').Id]*(num_pad_tokens)
+        seq += [self.tokenizer.get_command('pad').Id] * (num_pad_tokens)
         return seq
     def contains_sentence_end(self, tok):
@@ -569,6 +581,7 @@ class GPT2Dataset(data.Dataset):
             return True
         return False
 class bert_sentencepair_dataset(data.Dataset):
     """
     Dataset containing sentencepairs for BERT training. Each index corresponds to a randomly generated sentence pair.
@@ -581,7 +594,9 @@ class bert_sentencepair_dataset(data.Dataset):
         dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
     """
-    def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True, **kwargs):
+    def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None,
+                 short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True, **kwargs):
         self.ds = ds
         self.ds_len = len(self.ds)
         self.tokenizer = self.ds.GetTokenizer()
@@ -590,12 +605,12 @@ class bert_sentencepair_dataset(data.Dataset):
         self.max_seq_len = max_seq_len
         self.mask_lm_prob = mask_lm_prob
         if max_preds_per_seq is None:
-            max_preds_per_seq = math.ceil(max_seq_len*mask_lm_prob /10)*10
+            max_preds_per_seq = math.ceil(max_seq_len * mask_lm_prob / 10) * 10
         self.max_preds_per_seq = max_preds_per_seq
         self.short_seq_prob = short_seq_prob
         self.dataset_size = dataset_size
         if self.dataset_size is None:
-            self.dataset_size = self.ds_len * (self.ds_len-1)
+            self.dataset_size = self.ds_len * (self.ds_len - 1)
         self.presplit_sentences = presplit_sentences
         if not self.presplit_sentences:
             nltk.download('punkt', download_dir="./nltk")
@@ -607,7 +622,8 @@ class bert_sentencepair_dataset(data.Dataset):
             if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
                 lens = np.array(self.ds.lens)
             else:
-                lens = np.array([len(d['text']) if isinstance(d, dict) else len(d) for d in self.ds])
+                lens = np.array([len(d['text']) if isinstance(d, dict) else len(d)
+                                 for d in self.ds])
             self.total_len = np.sum(lens)
             self.weighting = list(accumulate(lens))
         else:
@@ -626,7 +642,7 @@ class bert_sentencepair_dataset(data.Dataset):
     def __getitem__(self, idx):
         # get rng state corresponding to index (allows deterministic random pair)
         rng = random.Random(idx)
-        np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
+        np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
         # get seq length
         target_seq_length = self.max_seq_len
         short_seq = False
@@ -639,15 +655,25 @@ class bert_sentencepair_dataset(data.Dataset):
         lena = 0
         lenb = 0
         while (is_random_next is None) or (lena < 1) or (lenb < 1):
-            tokensa, tokensb, is_random_next = self.create_random_sentencepair(target_seq_length, rng, np_rng)
+            tokensa, tokensb, is_random_next = self.create_random_sentencepair(
+                target_seq_length, rng, np_rng)
             lena = len(tokensa[0])
             lenb = len(tokensb[0])
         # truncate sentence pair to max_seq_len
         tokensa, tokensb = self.truncate_seq_pair(tokensa, tokensb, self.max_seq_len, rng)
         # join sentence pair, mask, and pad
-        tokens, mask, mask_labels, pad_mask = self.create_masked_lm_predictions(tokensa, tokensb, self.mask_lm_prob, self.max_preds_per_seq, self.vocab_words, rng)
-        sample = {'text': np.array(tokens[0]), 'types': np.array(tokens[1]), 'is_random': int(is_random_next), 'mask': np.array(mask), 'mask_labels': np.array(mask_labels), 'pad_mask': np.array(pad_mask)}
+        tokens, mask, mask_labels, pad_mask = self.create_masked_lm_predictions(
+            tokensa, tokensb, self.mask_lm_prob, self.max_preds_per_seq, self.vocab_words, rng)
+        sample = {
+            'text': np.array(
+                tokens[0]),
+            'types': np.array(
+                tokens[1]),
+            'is_random': int(is_random_next),
+            'mask': np.array(mask),
+            'mask_labels': np.array(mask_labels),
+            'pad_mask': np.array(pad_mask)}
         return sample
     def sentence_split(self, document):
@@ -665,7 +691,7 @@ class bert_sentencepair_dataset(data.Dataset):
         """tokenize sentence and get token types"""
         tokens = self.tokenizer.EncodeAsIds(sent).tokenization
         str_type = 'str' + str(sentence_num)
-        token_types = [self.tokenizer.get_type(str_type).Id]*len(tokens)
+        token_types = [self.tokenizer.get_type(str_type).Id] * len(tokens)
         return tokens, token_types
     def get_doc(self, idx):
@@ -694,21 +720,22 @@ class bert_sentencepair_dataset(data.Dataset):
                 # doc_a_idx = np_rng.choice(self.ds_len, p=self.weighting)
                 doc_a_idx = self.get_weighted_samples(np_rng)
             else:
-                doc_a_idx = rng.randint(0, self.ds_len-1)
+                doc_a_idx = rng.randint(0, self.ds_len - 1)
             doc_a = self.sentence_split(self.get_doc(doc_a_idx))
             if not doc_a:
                 doc_a = None
-        random_start_a = rng.randint(0, len(doc_a)-1)
+        random_start_a = rng.randint(0, len(doc_a) - 1)
         while random_start_a < len(doc_a):
             sentence = doc_a[random_start_a]
-            sentence, sentence_types = self.sentence_tokenize(sentence, 0, random_start_a == 0, random_start_a == len(doc_a))
+            sentence, sentence_types = self.sentence_tokenize(
+                sentence, 0, random_start_a == 0, random_start_a == len(doc_a))
             curr_strs.append(sentence)
             curr_str_types.append(sentence_types)
             curr_len += len(sentence)
             if random_start_a == len(doc_a) - 1 or curr_len >= target_seq_length:
                 break
-            random_start_a = (random_start_a+1)
+            random_start_a = (random_start_a + 1)
         if curr_strs:
             num_a = 1
@@ -738,16 +765,17 @@ class bert_sentencepair_dataset(data.Dataset):
                 if not doc_b:
                     doc_b = None
-            random_start_b = rng.randint(0, len(doc_b)-1)
+            random_start_b = rng.randint(0, len(doc_b) - 1)
             while random_start_b < len(doc_b):
                 sentence_b = doc_b[random_start_b]
-                new_b_tokens, new_b_types = self.sentence_tokenize(sentence_b, 1, random_start_b == 0, random_start_b == len(doc_b))
+                new_b_tokens, new_b_types = self.sentence_tokenize(
+                    sentence_b, 1, random_start_b == 0, random_start_b == len(doc_b))
                 b_len += len(new_b_tokens)
                 tokens_b.extend(new_b_tokens)
                 token_types_b.extend(new_b_types)
                 if len(tokens_b) >= target_b_length:
                     break
-                random_start_b = (random_start_b+1)
+                random_start_b = (random_start_b + 1)
         else:
             is_random_next = False
             for j in range(num_a, len(curr_strs)):
@@ -812,13 +840,15 @@ class bert_sentencepair_dataset(data.Dataset):
     def pad_seq(self, seq):
         """helper function to pad sequence pair"""
         num_pad = max(0, self.max_seq_len - len(seq))
         pad_mask = [0] * len(seq) + [1] * num_pad
         seq += [self.tokenizer.get_command('pad').Id] * num_pad
         return seq, pad_mask
     def concat_tokens(self, tokens_a, token_types_a, tokens_b, token_types_b):
-        tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
-        token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
+        tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command(
+            'sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
+        token_types = [token_types_a[0]] + token_types_a + \
+            [token_types_a[0]] + token_types_b + [token_types_b[0]]
         return tokens, token_types
     def create_masked_lm_predictions(self, a, b, mask_lm_prob, max_preds_per_seq, vocab_words, rng):
@@ -833,7 +863,7 @@ class bert_sentencepair_dataset(data.Dataset):
         len_a = len(tokens_a)
         len_b = len(tokens_b)
-        cand_indices = [idx+1 for idx in range(len_a)] + [idx+2+len_a for idx in range(len_b)]
+        cand_indices = [idx + 1 for idx in range(len_a)] + [idx + 2 + len_a for idx in range(len_b)]
         rng.shuffle(cand_indices)
...
@@ -169,7 +169,7 @@ def http_get(url, temp_file):
     total = int(content_length) if content_length is not None else None
     progress = tqdm(unit="B", total=total)
     for chunk in req.iter_content(chunk_size=1024):
-        if chunk: # filter out keep-alive new chunks
+        if chunk:  # filter out keep-alive new chunks
             progress.update(len(chunk))
             temp_file.write(chunk)
     progress.close()
...
...@@ -22,11 +22,13 @@ from itertools import accumulate ...@@ -22,11 +22,13 @@ from itertools import accumulate
import torch import torch
from torch.multiprocessing import Lock from torch.multiprocessing import Lock
def get_lazy_path(path): def get_lazy_path(path):
""" """
Gets directory path where lazy files are stored. Gets directory path where lazy files are stored.
""" """
return os.path.splitext(path)[0]+'.lazy' return os.path.splitext(path)[0] + '.lazy'
def exists_lazy(path, data_type='data'): def exists_lazy(path, data_type='data'):
""" """
...@@ -37,10 +39,11 @@ def exists_lazy(path, data_type='data'): ...@@ -37,10 +39,11 @@ def exists_lazy(path, data_type='data'):
contents = os.listdir(get_lazy_path(path)) contents = os.listdir(get_lazy_path(path))
if data_type not in contents: if data_type not in contents:
return False return False
if data_type+'.len.pkl' not in contents: if data_type + '.len.pkl' not in contents:
return False return False
return True return True
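get_lazy_path and exists_lazy together just probe for the on-disk layout that make_lazy produces; a small sketch with a hypothetical corpus path:

import os

path = 'corpus.json'                             # hypothetical input corpus
lazypath = os.path.splitext(path)[0] + '.lazy'   # -> 'corpus.lazy'
# exists_lazy(path) is True only when 'corpus.lazy/' contains both
# 'data' (all documents concatenated) and 'data.len.pkl' (their byte lengths).
print(lazypath, os.path.exists(lazypath))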
def make_lazy(path, strs, data_type='data'): def make_lazy(path, strs, data_type='data'):
""" """
Make lazy version of `data_type` field of the file. Byte offsets Make lazy version of `data_type` field of the file. Byte offsets
...@@ -50,7 +53,7 @@ def make_lazy(path, strs, data_type='data'): ...@@ -50,7 +53,7 @@ def make_lazy(path, strs, data_type='data'):
if not os.path.exists(lazypath): if not os.path.exists(lazypath):
os.makedirs(lazypath) os.makedirs(lazypath)
datapath = os.path.join(lazypath, data_type) datapath = os.path.join(lazypath, data_type)
lenpath = os.path.join(lazypath, data_type+'.len.pkl') lenpath = os.path.join(lazypath, data_type + '.len.pkl')
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
with open(datapath, 'wb') as f: with open(datapath, 'wb') as f:
str_lens = [] str_lens = []
...@@ -67,28 +70,32 @@ def make_lazy(path, strs, data_type='data'): ...@@ -67,28 +70,32 @@ def make_lazy(path, strs, data_type='data'):
while not os.path.exists(lenpath): while not os.path.exists(lenpath):
time.sleep(1) time.sleep(1)
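In essence make_lazy writes one flat byte file plus a pickled list of per-document byte lengths; a simplified sketch of that serialization (names are illustrative and the distributed-rank guard is omitted):

import pickle as pkl

def make_lazy_sketch(datapath, lenpath, strs):
    # Concatenate every string into one file and record each encoded length,
    # so items can later be sliced back out by byte offset.
    str_lens = []
    with open(datapath, 'wb') as f:
        for s in strs:
            encoded = s.encode('utf-8')
            f.write(encoded)
            str_lens.append(len(encoded))
    with open(lenpath, 'wb') as f:
        pkl.dump(str_lens, f)

# make_lazy_sketch('corpus.lazy/data', 'corpus.lazy/data.len.pkl', ['doc one', 'doc two'])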
def split_strings(strings, start, chr_lens): def split_strings(strings, start, chr_lens):
""" """
Split strings based on string lengths and given start. Split strings based on string lengths and given start.
""" """
return [strings[i-start:j-start] for i, j in zip([start]+chr_lens[:-1], chr_lens)] return [strings[i - start:j - start] for i, j in zip([start] + chr_lens[:-1], chr_lens)]
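A worked example of split_strings: items 'abc', 'defgh', 'ijkl' concatenated give cumulative ends [3, 8, 12], so reading items 1..2 means start = 3 and a raw read of 'defghijkl':

chr_lens = [8, 12]            # cumulative ends of the requested items
start = 3                     # end offset of the last item before the slice
strings = 'defghijkl'         # raw characters read from the data file
print([strings[i - start:j - start] for i, j in zip([start] + chr_lens[:-1], chr_lens)])
# -> ['defgh', 'ijkl']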
class ProcessorTokenizer: class ProcessorTokenizer:
""" """
callable class that runs a preprocessing, as well as tokenization step, callable class that runs a preprocessing, as well as tokenization step,
on input text. on input text.
""" """
def __init__(self, tokenizer, process_fn=None): def __init__(self, tokenizer, process_fn=None):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.process_fn = process_fn self.process_fn = process_fn
def __call__(self, string): def __call__(self, string):
if self.tokenizer is not None: if self.tokenizer is not None:
string = self.tokenizer(string, process_fn=self.process_fn) string = self.tokenizer(string, process_fn=self.process_fn)
elif self.process_fn is not None: elif self.process_fn is not None:
string = self.process_fn(string) string = self.process_fn(string)
return string return string
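ProcessorTokenizer just chains an optional preprocessing function with an optional tokenizer; a hedged usage sketch with stand-in callables (fake_tokenizer below only mimics the process_fn keyword the repo's tokenizers accept):

def fake_tokenizer(string, process_fn=None):
    # Stand-in for the repo's tokenizer interface.
    if process_fn is not None:
        string = process_fn(string)
    return string.split()

proc = ProcessorTokenizer(fake_tokenizer, process_fn=str.lower)
print(proc('Hello World'))         # -> ['hello', 'world']

proc_no_tok = ProcessorTokenizer(None, process_fn=str.lower)
print(proc_no_tok('Hello World'))  # -> 'hello world'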
class lazy_array_loader(object): class lazy_array_loader(object):
""" """
Arguments: Arguments:
...@@ -107,17 +114,18 @@ class lazy_array_loader(object): ...@@ -107,17 +114,18 @@ class lazy_array_loader(object):
data_type2 data_type2
data_type2.len.pkl data_type2.len.pkl
""" """
def __init__(self, path, data_type='data', mem_map=False, map_fn=None): def __init__(self, path, data_type='data', mem_map=False, map_fn=None):
lazypath = get_lazy_path(path) lazypath = get_lazy_path(path)
datapath = os.path.join(lazypath, data_type) datapath = os.path.join(lazypath, data_type)
#get file where array entries are concatenated into one big string # get file where array entries are concatenated into one big string
self._file = open(datapath, 'rb', buffering=0) self._file = open(datapath, 'rb', buffering=0)
self.file = self._file self.file = self._file
#memory map file if necessary # memory map file if necessary
self.mem_map = mem_map self.mem_map = mem_map
if self.mem_map: if self.mem_map:
self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ) self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ)
lenpath = os.path.join(lazypath, data_type+'.len.pkl') lenpath = os.path.join(lazypath, data_type + '.len.pkl')
self.lens = pkl.load(open(lenpath, 'rb')) self.lens = pkl.load(open(lenpath, 'rb'))
self.ends = list(accumulate(self.lens)) self.ends = list(accumulate(self.lens))
self.dumb_ends = list(self.ends) self.dumb_ends = list(self.ends)
...@@ -149,7 +157,7 @@ class lazy_array_loader(object): ...@@ -149,7 +157,7 @@ class lazy_array_loader(object):
if index == 0: if index == 0:
start = 0 start = 0
else: else:
start = self.ends[index-1] start = self.ends[index - 1]
end = self.ends[index] end = self.ends[index]
rtn = self.file_read(start, end) rtn = self.file_read(start, end)
if self.map_fn is not None: if self.map_fn is not None:
...@@ -160,7 +168,7 @@ class lazy_array_loader(object): ...@@ -160,7 +168,7 @@ class lazy_array_loader(object):
if index.start == 0 or index.start is None: if index.start == 0 or index.start is None:
start = 0 start = 0
else: else:
start = self.ends[index.start-1] start = self.ends[index.start - 1]
stop = chr_lens[-1] stop = chr_lens[-1]
strings = self.file_read(start, stop) strings = self.file_read(start, stop)
rtn = split_strings(strings, start, chr_lens) rtn = split_strings(strings, start, chr_lens)
...@@ -181,15 +189,14 @@ class lazy_array_loader(object): ...@@ -181,15 +189,14 @@ class lazy_array_loader(object):
# read to end of file if no end point provided # read to end of file if no end point provided
if end is None: if end is None:
rtn = self.file.read() rtn = self.file.read()
#else read amount needed to reach end point # else read amount needed to reach end point
else: else:
rtn = self.file.read(end-start) rtn = self.file.read(end - start)
self.read_lock.release() self.read_lock.release()
#TODO: @raulp figure out mem map byte string bug # TODO: @raulp figure out mem map byte string bug
#if mem map'd need to decode byte string to string # if mem map'd need to decode byte string to string
rtn = rtn.decode('utf-8', 'ignore') rtn = rtn.decode('utf-8', 'ignore')
# rtn = str(rtn) # rtn = str(rtn)
if self.mem_map: if self.mem_map:
rtn = rtn.decode('unicode_escape') rtn = rtn.decode('unicode_escape')
return rtn return rtn
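Indexing a lazy_array_loader turns the cumulative lengths into byte offsets and reads only the requested span from disk; a usage sketch with a hypothetical path (mem_map and map_fn left at their simple settings):

# 'corpus.lazy/' must already contain 'data' and 'data.len.pkl' (see make_lazy above).
loader = lazy_array_loader('corpus.json', data_type='data', mem_map=False)
print(len(loader.lens), 'documents')
print(loader[0][:80])   # first document, read lazily from the flat file
print(loader[2:4])      # slicing returns a list of documents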
...@@ -21,6 +21,7 @@ import torch ...@@ -21,6 +21,7 @@ import torch
from torch.utils import data from torch.utils import data
import numpy as np import numpy as np
class RandomSampler(data.sampler.Sampler): class RandomSampler(data.sampler.Sampler):
r""" r"""
Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler, Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
...@@ -63,7 +64,8 @@ class RandomSampler(data.sampler.Sampler): ...@@ -63,7 +64,8 @@ class RandomSampler(data.sampler.Sampler):
if self.epoch >= 0: if self.epoch >= 0:
g.manual_seed(self.epoch) g.manual_seed(self.epoch)
if self.replacement: if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=g).tolist()) return iter(torch.randint(high=n, size=(self.num_samples,),
dtype=torch.int64, generator=g).tolist())
return iter(torch.randperm(n, generator=g).tolist()) return iter(torch.randperm(n, generator=g).tolist())
def __len__(self): def __len__(self):
...@@ -72,12 +74,14 @@ class RandomSampler(data.sampler.Sampler): ...@@ -72,12 +74,14 @@ class RandomSampler(data.sampler.Sampler):
def set_epoch(self, epoch): def set_epoch(self, epoch):
self.epoch = epoch self.epoch = epoch
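The epoch seeding in __iter__ is what makes the shuffle reproducible across data-parallel workers; a hedged usage sketch, assuming the constructor mirrors torch.utils.data.RandomSampler(data_source):

dataset = list(range(10))     # any sized container works for illustration
sampler = RandomSampler(dataset)
sampler.set_epoch(3)          # same epoch -> same permutation on every rank
print(list(iter(sampler)))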
class DistributedBatchSampler(data.sampler.BatchSampler): class DistributedBatchSampler(data.sampler.BatchSampler):
""" """
similar to normal implementation of distributed sampler, except implementation is at the similar to normal implementation of distributed sampler, except implementation is at the
batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary
data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
""" """
def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False): def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False):
super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last) super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
if rank == -1: if rank == -1:
...@@ -125,7 +129,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler): ...@@ -125,7 +129,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
def data_iterator(self, _iter, wrap_around=False): def data_iterator(self, _iter, wrap_around=False):
"""iterates through data and handles wrap around""" """iterates through data and handles wrap around"""
for i, idx in enumerate(_iter): for i, idx in enumerate(_iter):
if i < self.wrap_around%self.batch_size: if i < self.wrap_around % self.batch_size:
continue continue
if wrap_around: if wrap_around:
self.wrap_around += 1 self.wrap_around += 1
...@@ -134,6 +138,6 @@ class DistributedBatchSampler(data.sampler.BatchSampler): ...@@ -134,6 +138,6 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
def _batch(self, batch): def _batch(self, batch):
"""extracts samples only pertaining to this worker's batch""" """extracts samples only pertaining to this worker's batch"""
start = self.rank*self.batch_size//self.world_size start = self.rank * self.batch_size // self.world_size
end = (self.rank+1)*self.batch_size//self.world_size end = (self.rank + 1) * self.batch_size // self.world_size
return batch[start:end] return batch[start:end]
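_batch hands each rank a contiguous shard of the global batch; for batch_size=8 and world_size=2, rank 0 takes indices [0:4] and rank 1 takes [4:8]:

batch = list(range(8))        # one global batch of sample indices
batch_size, world_size = 8, 2
for rank in range(world_size):
    start = rank * batch_size // world_size
    end = (rank + 1) * batch_size // world_size
    print(rank, batch[start:end])
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]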
...@@ -16,12 +16,12 @@ output_file = sys.argv[2] ...@@ -16,12 +16,12 @@ output_file = sys.argv[2]
line_seperator = "\n" line_seperator = "\n"
with open(input_file, 'r') as ifile: with open(input_file, 'r') as ifile:
with open(output_file, "w") as ofile: with open(output_file, "w") as ofile:
for doc in ifile.readlines(): for doc in ifile.readlines():
parsed = json.loads(doc) parsed = json.loads(doc)
sent_list = [] sent_list = []
for line in parsed['text'].split('\n'): for line in parsed['text'].split('\n'):
if line != '\n': if line != '\n':
sent_list.extend(nltk.tokenize.sent_tokenize(line)) sent_list.extend(nltk.tokenize.sent_tokenize(line))
parsed['text'] = line_seperator.join(sent_list) parsed['text'] = line_seperator.join(sent_list)
ofile.write(json.dumps(parsed)+'\n') ofile.write(json.dumps(parsed) + '\n')
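This script rewrites each loose-json document so every sentence sits on its own line; the core transformation, sketched on a single in-memory document (nltk's punkt data must be downloaded once beforehand):

import json
import nltk

# nltk.download('punkt')  # one-time setup for sent_tokenize
doc = {'text': 'First sentence. Second sentence.\nAnother paragraph here.'}
sent_list = []
for line in doc['text'].split('\n'):
    if line != '\n':
        sent_list.extend(nltk.tokenize.sent_tokenize(line))
doc['text'] = '\n'.join(sent_list)
print(json.dumps(doc))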
...@@ -18,7 +18,7 @@ Takes a corpora of files (specified by `--input_files`) with json data separated ...@@ -18,7 +18,7 @@ Takes a corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, val.json, test.json files by newlines (loose json). Splits data into train.json, val.json, test.json files
under `output_dir`. under `output_dir`.
Note: This code has the potential to overwrite files with the names Note: This code has the potential to overwrite files with the names
train.json, val.json, test.json in `--output_dir`. train.json, val.json, test.json in `--output_dir`.
""" """
import os import os
...@@ -35,6 +35,7 @@ parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0], ...@@ -35,6 +35,7 @@ parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
help='percentage of available data to use for val/test dataset') help='percentage of available data to use for val/test dataset')
args = parser.parse_args() args = parser.parse_args()
def get_lines(filepath): def get_lines(filepath):
lines = [] lines = []
with open(filepath, 'r') as f: with open(filepath, 'r') as f:
...@@ -43,6 +44,7 @@ def get_lines(filepath): ...@@ -43,6 +44,7 @@ def get_lines(filepath):
lines.append(l) lines.append(l)
return lines return lines
def get_splits(lines, line_counts): def get_splits(lines, line_counts):
all_lines = [] all_lines = []
line_idx = [] line_idx = []
...@@ -50,14 +52,14 @@ def get_splits(lines, line_counts): ...@@ -50,14 +52,14 @@ def get_splits(lines, line_counts):
for i, l in enumerate(lines): for i, l in enumerate(lines):
all_lines.extend(l) all_lines.extend(l)
line_idx.extend(list(range(len(l)))) line_idx.extend(list(range(len(l))))
file_mappings.extend([i]*len(l)) file_mappings.extend([i] * len(l))
indices = list(range(len(all_lines))) indices = list(range(len(all_lines)))
random.shuffle(indices) random.shuffle(indices)
all_lines = [all_lines[idx] for idx in indices] all_lines = [all_lines[idx] for idx in indices]
line_idx = [line_idx[idx] for idx in indices] line_idx = [line_idx[idx] for idx in indices]
file_mappings = [file_mappings[idx] for idx in indices] file_mappings = [file_mappings[idx] for idx in indices]
splits = [] splits = []
mappings = [] mappings = []
start = 0 start = 0
...@@ -68,10 +70,11 @@ def get_splits(lines, line_counts): ...@@ -68,10 +70,11 @@ def get_splits(lines, line_counts):
start = end start = end
return splits, mappings return splits, mappings
def format_mappings(line_idx, file_mappings): def format_mappings(line_idx, file_mappings):
lines = [] lines = []
for m, l in zip(file_mappings, line_idx): for m, l in zip(file_mappings, line_idx):
lines.append(str(m).strip()+'\t'+str(l).strip()) lines.append(str(m).strip() + '\t' + str(l).strip())
return lines return lines
...@@ -85,25 +88,30 @@ def get_filepaths(filepaths, output_dir): ...@@ -85,25 +88,30 @@ def get_filepaths(filepaths, output_dir):
paths.append(os.path.join(output_dir, test_path)) paths.append(os.path.join(output_dir, test_path))
return paths return paths
def write_files(lines, mappings, filepaths): def write_files(lines, mappings, filepaths):
for l, m, path in zip(lines, mappings, filepaths): for l, m, path in zip(lines, mappings, filepaths):
write_file(l, path) write_file(l, path)
write_mapping_file(m, path) write_mapping_file(m, path)
def write_file(lines, path): def write_file(lines, path):
print('Writing:', path) print('Writing:', path)
with open(path, 'w') as f: with open(path, 'w') as f:
for l in lines: for l in lines:
f.write(l+'\n') f.write(l + '\n')
def write_mapping_file(m, path): def write_mapping_file(m, path):
path = path+'.map' path = path + '.map'
m = [get_mapping_header()]+m m = [get_mapping_header()] + m
write_file(m, path) write_file(m, path)
def get_mapping_header(): def get_mapping_header():
return 'file\tline #' return 'file\tline #'
if not os.path.exists(args.output_dir): if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir) os.makedirs(args.output_dir)
...@@ -113,16 +121,16 @@ for filepath in args.input_files: ...@@ -113,16 +121,16 @@ for filepath in args.input_files:
_lines = get_lines(filepath) _lines = get_lines(filepath)
lines.append(_lines) lines.append(_lines)
#calculate number of lines to use for each # calculate number of lines to use for each
line_counts = [len(l) for l in lines] line_counts = [len(l) for l in lines]
total_lines = sum(line_counts) total_lines = sum(line_counts)
dev_percent = args.test_percent[0] dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent*total_lines) dev_lines = math.ceil(dev_percent * total_lines)
test_percent = 0 test_percent = 0
if len(args.test_percent)==2: if len(args.test_percent) == 2:
test_percent=args.test_percent[1] test_percent = args.test_percent[1]
test_lines = math.ceil(test_percent*total_lines) test_lines = math.ceil(test_percent * total_lines)
train_lines = total_lines-(test_lines+dev_lines) train_lines = total_lines - (test_lines + dev_lines)
normed_lines = [train_lines, dev_lines, test_lines] normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines] normed_lines = [int(l) for l in normed_lines]
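A worked example of the split arithmetic: with 1,000 total lines and --test_percent 0.05 0.01, the dev split gets ceil(0.05 * 1000) = 50 lines, the test split gets ceil(0.01 * 1000) = 10, and train keeps the remaining 940:

import math

total_lines = 1000
test_percent_arg = [0.05, 0.01]                              # --test_percent 0.05 0.01
dev_lines = math.ceil(test_percent_arg[0] * total_lines)     # 50
test_lines = math.ceil(test_percent_arg[1] * total_lines)    # 10
train_lines = total_lines - (test_lines + dev_lines)         # 940
print([train_lines, dev_lines, test_lines])                  # [940, 50, 10]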
...@@ -131,4 +139,3 @@ splits, mappings = get_splits(lines, normed_lines) ...@@ -131,4 +139,3 @@ splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir) filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths) print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths) write_files(splits, mappings, filepaths)
...@@ -3,7 +3,7 @@ Takes a corpora of files (specified by `--input_files`) with json data separated ...@@ -3,7 +3,7 @@ Takes a corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, val.json, test.json files by newlines (loose json). Splits data into train.json, val.json, test.json files
under `output_dir`. under `output_dir`.
Note: This code has the potential to overwrite files with the names Note: This code has the potential to overwrite files with the names
train.json, val.json, test.json in `--output_dir`. train.json, val.json, test.json in `--output_dir`.
""" """
import os import os
...@@ -20,6 +20,7 @@ parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0], ...@@ -20,6 +20,7 @@ parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
help='percentage of available data to use for val/test dataset') help='percentage of available data to use for val/test dataset')
args = parser.parse_args() args = parser.parse_args()
def get_lines(filepath): def get_lines(filepath):
lines = [] lines = []
with open(filepath, 'r') as f: with open(filepath, 'r') as f:
...@@ -28,6 +29,7 @@ def get_lines(filepath): ...@@ -28,6 +29,7 @@ def get_lines(filepath):
lines.append(l) lines.append(l)
return lines return lines
def get_splits(lines, line_counts): def get_splits(lines, line_counts):
all_lines = [] all_lines = []
line_idx = [] line_idx = []
...@@ -35,14 +37,14 @@ def get_splits(lines, line_counts): ...@@ -35,14 +37,14 @@ def get_splits(lines, line_counts):
for i, l in enumerate(lines): for i, l in enumerate(lines):
all_lines.extend(l) all_lines.extend(l)
line_idx.extend(list(range(len(l)))) line_idx.extend(list(range(len(l))))
file_mappings.extend([i]*len(l)) file_mappings.extend([i] * len(l))
indices = list(range(len(all_lines))) indices = list(range(len(all_lines)))
random.shuffle(indices) random.shuffle(indices)
all_lines = [all_lines[idx] for idx in indices] all_lines = [all_lines[idx] for idx in indices]
line_idx = [line_idx[idx] for idx in indices] line_idx = [line_idx[idx] for idx in indices]
file_mappings = [file_mappings[idx] for idx in indices] file_mappings = [file_mappings[idx] for idx in indices]
splits = [] splits = []
mappings = [] mappings = []
start = 0 start = 0
...@@ -53,10 +55,11 @@ def get_splits(lines, line_counts): ...@@ -53,10 +55,11 @@ def get_splits(lines, line_counts):
start = end start = end
return splits, mappings return splits, mappings
def format_mappings(line_idx, file_mappings): def format_mappings(line_idx, file_mappings):
lines = [] lines = []
for m, l in zip(file_mappings, line_idx): for m, l in zip(file_mappings, line_idx):
lines.append(str(m).strip()+'\t'+str(l).strip()) lines.append(str(m).strip() + '\t' + str(l).strip())
return lines return lines
...@@ -70,25 +73,30 @@ def get_filepaths(filepaths, output_dir): ...@@ -70,25 +73,30 @@ def get_filepaths(filepaths, output_dir):
paths.append(os.path.join(output_dir, test_path)) paths.append(os.path.join(output_dir, test_path))
return paths return paths
def write_files(lines, mappings, filepaths): def write_files(lines, mappings, filepaths):
for l, m, path in zip(lines, mappings, filepaths): for l, m, path in zip(lines, mappings, filepaths):
write_file(l, path) write_file(l, path)
write_mapping_file(m, path) write_mapping_file(m, path)
def write_file(lines, path): def write_file(lines, path):
print('Writing:', path) print('Writing:', path)
with open(path, 'w') as f: with open(path, 'w') as f:
for l in lines: for l in lines:
f.write(l+'\n') f.write(l + '\n')
def write_mapping_file(m, path): def write_mapping_file(m, path):
path = path+'.map' path = path + '.map'
m = [get_mapping_header()]+m m = [get_mapping_header()] + m
write_file(m, path) write_file(m, path)
def get_mapping_header(): def get_mapping_header():
return 'file\tline #' return 'file\tline #'
if not os.path.exists(args.output_dir): if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir) os.makedirs(args.output_dir)
...@@ -98,16 +106,16 @@ for filepath in args.input_files: ...@@ -98,16 +106,16 @@ for filepath in args.input_files:
_lines = get_lines(filepath) _lines = get_lines(filepath)
lines.append(_lines) lines.append(_lines)
#calculate number of lines to use for each # calculate number of lines to use for each
line_counts = [len(l) for l in lines] line_counts = [len(l) for l in lines]
total_lines = sum(line_counts) total_lines = sum(line_counts)
dev_percent = args.test_percent[0] dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent*total_lines) dev_lines = math.ceil(dev_percent * total_lines)
test_percent = 0 test_percent = 0
if len(args.test_percent)==2: if len(args.test_percent) == 2:
test_percent=args.test_percent[1] test_percent = args.test_percent[1]
test_lines = math.ceil(test_percent*total_lines) test_lines = math.ceil(test_percent * total_lines)
train_lines = total_lines-(test_lines+dev_lines) train_lines = total_lines - (test_lines + dev_lines)
normed_lines = [train_lines, dev_lines, test_lines] normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines] normed_lines = [int(l) for l in normed_lines]
...@@ -116,4 +124,3 @@ splits, mappings = get_splits(lines, normed_lines) ...@@ -116,4 +124,3 @@ splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir) filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths) print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths) write_files(splits, mappings, filepaths)
...@@ -14,20 +14,22 @@ ...@@ -14,20 +14,22 @@
# limitations under the License. # limitations under the License.
"""PyTorch DataLoader for TFRecords""" """PyTorch DataLoader for TFRecords"""
import numpy as np
import torch
import queue import queue
import threading import threading
import tensorflow as tf import tensorflow as tf
tf.enable_eager_execution() tf.enable_eager_execution()
import torch
import numpy as np
class TFRecordDataLoader(object): class TFRecordDataLoader(object):
def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1, threaded_dl=False): def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq,
train, num_workers=2, seed=1, threaded_dl=False):
assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords" assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords"
tf.set_random_seed(seed) tf.set_random_seed(seed)
if isinstance(records, str): if isinstance(records, str):
records = [records] records = [records]
self.record_converter = Record2Example({"input_ids": tf.FixedLenFeature([max_seq_len], tf.int64), self.record_converter = Record2Example({"input_ids": tf.FixedLenFeature([max_seq_len], tf.int64),
"input_mask": tf.FixedLenFeature([max_seq_len], tf.int64), "input_mask": tf.FixedLenFeature([max_seq_len], tf.int64),
...@@ -37,7 +39,7 @@ class TFRecordDataLoader(object): ...@@ -37,7 +39,7 @@ class TFRecordDataLoader(object):
"masked_lm_weights": tf.FixedLenFeature([max_preds_per_seq], tf.float32), "masked_lm_weights": tf.FixedLenFeature([max_preds_per_seq], tf.float32),
"next_sentence_labels": tf.FixedLenFeature([1], tf.int64)}) "next_sentence_labels": tf.FixedLenFeature([1], tf.int64)})
#Instantiate dataset according to original BERT implementation # Instantiate dataset according to original BERT implementation
if train: if train:
self.dataset = tf.data.Dataset.from_tensor_slices(tf.constant(records)) self.dataset = tf.data.Dataset.from_tensor_slices(tf.constant(records))
self.dataset = self.dataset.repeat() self.dataset = self.dataset.repeat()
...@@ -55,10 +57,12 @@ class TFRecordDataLoader(object): ...@@ -55,10 +57,12 @@ class TFRecordDataLoader(object):
self.dataset = self.dataset.repeat() self.dataset = self.dataset.repeat()
# Instantiate dataloader (do not drop remainder for eval) # Instantiate dataloader (do not drop remainder for eval)
loader_args = {'batch_size': batch_size, loader_args = {'batch_size': batch_size,
'num_parallel_batches': num_workers, 'num_parallel_batches': num_workers,
'drop_remainder': train} 'drop_remainder': train}
self.dataloader = self.dataset.apply(tf.contrib.data.map_and_batch(self.record_converter, **loader_args)) self.dataloader = self.dataset.apply(
tf.contrib.data.map_and_batch(
self.record_converter, **loader_args))
self.threaded_dl = threaded_dl self.threaded_dl = threaded_dl
self.num_workers = num_workers self.num_workers = num_workers
...@@ -72,6 +76,7 @@ class TFRecordDataLoader(object): ...@@ -72,6 +76,7 @@ class TFRecordDataLoader(object):
for item in data_iter: for item in data_iter:
yield convert_tf_example_to_torch_tensors(item) yield convert_tf_example_to_torch_tensors(item)
class Record2Example(object): class Record2Example(object):
def __init__(self, feature_map): def __init__(self, feature_map):
self.feature_map = feature_map self.feature_map = feature_map
...@@ -84,23 +89,25 @@ class Record2Example(object): ...@@ -84,23 +89,25 @@ class Record2Example(object):
example[k] = tf.to_int32(v) example[k] = tf.to_int32(v)
return example return example
def convert_tf_example_to_torch_tensors(example): def convert_tf_example_to_torch_tensors(example):
item = {k: (v.numpy()) for k,v in example.items()} item = {k: (v.numpy()) for k, v in example.items()}
mask = np.zeros_like(item['input_ids']) mask = np.zeros_like(item['input_ids'])
mask_labels = np.ones_like(item['input_ids'])*-1 mask_labels = np.ones_like(item['input_ids']) * -1
for b, row in enumerate(item['masked_lm_positions'].astype(int)): for b, row in enumerate(item['masked_lm_positions'].astype(int)):
for i, idx in enumerate(row): for i, idx in enumerate(row):
if item['masked_lm_weights'][b, i] != 0: if item['masked_lm_weights'][b, i] != 0:
mask[b, idx] = 1 mask[b, idx] = 1
mask_labels[b, idx] = item['masked_lm_ids'][b, i] mask_labels[b, idx] = item['masked_lm_ids'][b, i]
output = {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'], output = {'text': item['input_ids'], 'types': item['segment_ids'], 'is_random': item['next_sentence_labels'],
'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels} 'pad_mask': 1 - item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
return {k: torch.from_numpy(v) for k,v in output.items()} return {k: torch.from_numpy(v) for k, v in output.items()}
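convert_tf_example_to_torch_tensors scatters the masked-LM positions into dense mask and label arrays and flips the padding mask; the scatter step alone, sketched with toy numpy inputs (no TFRecord required):

import numpy as np
import torch

# Toy batch: one sequence of length 6 with positions 1 and 3 masked.
input_ids = np.array([[5, 6, 7, 8, 0, 0]])
masked_lm_positions = np.array([[1, 3]])
masked_lm_ids = np.array([[60, 80]])
masked_lm_weights = np.array([[1.0, 1.0]])

mask = np.zeros_like(input_ids)
mask_labels = np.ones_like(input_ids) * -1
for b, row in enumerate(masked_lm_positions.astype(int)):
    for i, idx in enumerate(row):
        if masked_lm_weights[b, i] != 0:
            mask[b, idx] = 1
            mask_labels[b, idx] = masked_lm_ids[b, i]

print(torch.from_numpy(mask))         # tensor([[0, 1, 0, 1, 0, 0]])
print(torch.from_numpy(mask_labels))  # tensor([[-1, 60, -1, 80, -1, -1]])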
class MultiprocessLoader(object): class MultiprocessLoader(object):
def __init__(self, dataloader, num_workers=2): def __init__(self, dataloader, num_workers=2):
self.dl = dataloader self.dl = dataloader
self.queue_size = 2*num_workers self.queue_size = 2 * num_workers
def __iter__(self): def __iter__(self):
output_queue = queue.Queue(self.queue_size) output_queue = queue.Queue(self.queue_size)
...@@ -114,8 +121,9 @@ class MultiprocessLoader(object): ...@@ -114,8 +121,9 @@ class MultiprocessLoader(object):
else: else:
print(RuntimeError('TF record data loader thread exited unexpectedly')) print(RuntimeError('TF record data loader thread exited unexpectedly'))
def _multiproc_iter(dl, output_queue): def _multiproc_iter(dl, output_queue):
data_iter = iter(dl) data_iter = iter(dl)
for item in data_iter: for item in data_iter:
tensors = convert_tf_example_to_torch_tensors(item) tensors = convert_tf_example_to_torch_tensors(item)
output_queue.put(tensors, block=True) output_queue.put(tensors, block=True)
\ No newline at end of file
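Despite its name, MultiprocessLoader drives the TF pipeline from a background thread and hands converted tensors over a bounded queue; a generic sketch of that producer/consumer pattern (the produce function and sentinel are stand-ins, not the repo's API):

import queue
import threading

def producer(items, output_queue):
    # Fill the queue from a background thread, blocking when it is full.
    for item in items:
        output_queue.put(item, block=True)
    output_queue.put(None)                # sentinel: no more data

output_queue = queue.Queue(maxsize=4)
thread = threading.Thread(target=producer, args=(range(10), output_queue))
thread.daemon = True
thread.start()

while True:
    item = output_queue.get(block=True)
    if item is None:
        break
    print(item)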