Commit b7f1b050 authored by Neel Kant

Lint whole repo

parent c99fa80c
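The hunks below are mechanical style fixes: spaces around binary operators, blank lines before `def` and `class` statements, comments normalized to `# `, long lines wrapped, and bare `except:` clauses rewritten as `except BaseException:` (which still catches everything a bare `except:` did, so behavior is unchanged). The commit message does not say which tool was used; the sketch below is an assumption showing how an autopep8 pass in aggressive mode produces the same kind of rewrite.

```python
# Minimal sketch, assuming autopep8 was the formatter (the commit does not
# say which tool produced these hunks). autopep8.fix_code is a real API;
# the options chosen here are illustrative.
import autopep8

before = (
    "try:\n"
    "    x = 1\n"
    "except:\n"
    "    pass\n"
)
after = autopep8.fix_code(before, options={"aggressive": 2})
print(after)
# Aggressive mode rewrites the bare 'except:' as 'except BaseException:',
# matching the load_checkpoint and csv_dataset hunks in this commit.
```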
......@@ -357,7 +357,6 @@ def _add_gpt2_args(parser):
return parser
def add_data_args_(parser):
"""Train/valid/test data arguments."""
......@@ -367,6 +366,4 @@ def add_data_args_(parser):
choices=['raw', 'lazy', 'tfrecords', 'numpy', 'binary'],
help='Which data loader to use. Default varies by model.')
return parser
......@@ -67,7 +67,7 @@ def get_checkpoint_name(checkpoints_path, iteration,
directory = 'iter_{:07d}'.format(iteration)
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_model_parallel_rank() if mp_rank is None \
mpu.get_model_parallel_rank() if mp_rank is None
else mp_rank),
'model_optim_rng.pt')
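The only change in the hunk above is dropping a now-unnecessary line continuation; the path it builds is unchanged. For reference, the layout is iter_<7-digit iteration>/mp_rank_<2-digit rank>/model_optim_rng.pt, illustrated below with hypothetical values.

```python
# Illustration only: the checkpoint path assembled by get_checkpoint_name
# above, for a hypothetical checkpoints_path, iteration, and model-parallel rank.
import os

checkpoints_path, iteration, mp_rank = 'checkpoints', 5000, 3
directory = 'iter_{:07d}'.format(iteration)
print(os.path.join(checkpoints_path, directory,
                   'mp_rank_{:02d}'.format(mp_rank),
                   'model_optim_rng.pt'))
# checkpoints/iter_0005000/mp_rank_03/model_optim_rng.pt
```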
......@@ -179,7 +179,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
'megatron.fp16.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
except:
except BaseException:
print_rank_0('could not load the checkpoint')
sys.exit()
......@@ -190,7 +190,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
try:
iteration = state_dict['iteration']
except KeyError:
try: # Backward compatible with older checkpoints
try: # Backward compatible with older checkpoints
iteration = state_dict['total_iters']
except KeyError:
print_rank_0('A metadata file exists but unable to load '
......
......@@ -47,6 +47,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
......@@ -113,7 +114,6 @@ class BertDataset(Dataset):
# Dataset.
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
data_prefix,
......@@ -133,11 +133,9 @@ class BertDataset(Dataset):
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_index, end_index, seq_length = self.samples_mapping[idx]
......@@ -148,7 +146,7 @@ class BertDataset(Dataset):
# python randint is inclusive whereas the numpy one is exclusive.
np_rng = np.random.RandomState(seed=(self.seed + idx))
return build_training_sample(sample, seq_length,
self.max_seq_length, # needed for padding
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id, self.sep_id,
......@@ -192,7 +190,7 @@ def get_train_valid_test_split_(splits_string, size):
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split/splits_sum for split in splits]
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] +
......@@ -254,7 +252,7 @@ def get_samples_mapping_(indexed_dataset,
indexed_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length-3, # account for added tokens
max_seq_length - 3, # account for added tokens
short_seq_prob,
seed,
verbose)
......
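For context on the get_train_valid_test_split_ hunk above: the split weights are normalized to sum to 1 and then turned into cumulative index boundaries. The continuation of the splits_index.append(...) statement is cut off by the diff context; the rounding used below is an assumption that matches the displayed prefix.

```python
# Worked example with illustrative numbers (a 1000-document corpus and a
# "949,50,1" split string). The int(round(...)) on the wrapped line is an
# assumption; only the first half of that statement appears in the hunk.
size = 1000
splits = [949.0, 50.0, 1.0]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
    splits_index.append(splits_index[index] + int(round(split * float(size))))
print(splits_index)  # [0, 949, 999, 1000]
```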
......@@ -42,6 +42,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
......@@ -54,7 +55,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
documents = np.arange(start=splits[index], stop=splits[index+1],
documents = np.arange(start=splits[index], stop=splits[index + 1],
step=1, dtype=np.int32)
dataset = GPT2Dataset(name, data_prefix,
documents, indexed_dataset,
......@@ -102,21 +103,19 @@ class GPT2Dataset(torch.utils.data.Dataset):
self.name, data_prefix, documents, self.indexed_dataset.sizes,
num_samples, seq_length, seed)
def __len__(self):
# -1 is due to data structure used to retieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
return self.sample_idx.shape[0] - 1
def __getitem__(self, idx):
# Get the shuffled index.
idx = self.shuffle_idx[idx]
# Start and end documents and offsets.
doc_index_f = self.sample_idx[idx][0]
doc_index_l = self.sample_idx[idx+1][0]
doc_index_l = self.sample_idx[idx + 1][0]
offset_f = self.sample_idx[idx][1]
offset_l = self.sample_idx[idx+1][1]
offset_l = self.sample_idx[idx + 1][1]
# If we are within the same document, just extract the chunk.
if doc_index_f == doc_index_l:
sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
......@@ -127,18 +126,17 @@ class GPT2Dataset(torch.utils.data.Dataset):
sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
offset=offset_f)]
# Loop over all in between documents and add the entire document.
for i in range(doc_index_f+1, doc_index_l):
for i in range(doc_index_f + 1, doc_index_l):
sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
# And finally add the relevant portion of last document.
sample_list.append(self.indexed_dataset.get(
self.doc_idx[doc_index_l],
length=offset_l+1))
length=offset_l + 1))
sample = np.concatenate(sample_list)
return {'text': np.array(sample, dtype=np.int64)}
def _build_index_mappings(name, data_prefix, documents, sizes,
num_samples, seq_length, seed):
"""Build doc-idx, sample-idx, and shuffle-idx.
......@@ -185,7 +183,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
assert sizes.dtype == np.int32
sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
num_epochs, tokens_per_epoch)
#sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
# sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
# num_epochs, tokens_per_epoch)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
print_rank_0(' > elasped time to build and save sample-idx mapping '
......@@ -194,7 +192,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
start_time = time.time()
# -1 is due to data structure used to retieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
shuffle_idx = _build_shuffle_idx(sample_idx.shape[0]-1, np_rng)
shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
print_rank_0(' > elasped time to build and save shuffle-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time))
......
......@@ -20,6 +20,7 @@ import numpy as np
import torch
from megatron import print_rank_0
def __best_fitting_dtype(vocab_size=None):
if vocab_size is not None and vocab_size < 65500:
return np.uint16
......@@ -109,13 +110,15 @@ def index_file_path(prefix_path):
def data_file_path(prefix_path):
return prefix_path + '.bin'
def create_doc_idx(sizes):
doc_idx = [0]
for i, s in enumerate(sizes):
if s == 0:
doc_idx.append(i+1)
doc_idx.append(i + 1)
return doc_idx
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for IndexedDataset"""
_HDR_MAGIC = b'TNTIDX\x00\x00'
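The create_doc_idx helper in the hunk above treats a zero-length entry in sizes as an end-of-document marker and records where the next document starts. A small self-contained example with made-up sizes:

```python
# create_doc_idx copied from the hunk above, exercised on illustrative sizes:
# each 0 in `sizes` ends a document, and the index after it starts a new one.
def create_doc_idx(sizes):
    doc_idx = [0]
    for i, s in enumerate(sizes):
        if s == 0:
            doc_idx.append(i + 1)
    return doc_idx

print(create_doc_idx([5, 3, 0, 7, 0, 2]))  # [0, 3, 5]
```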
......@@ -155,7 +158,7 @@ class IndexedDataset(torch.utils.data.Dataset):
if self.data_file:
self.data_file.close()
#@lru_cache(maxsize=8)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if not self.data_file:
self.read_data(self.path)
......@@ -235,7 +238,7 @@ class IndexedCachedDataset(IndexedDataset):
self.data_file.close()
self.data_file = None
#@lru_cache(maxsize=8)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
i = idx
......@@ -399,13 +402,18 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print_rank_0(" reading sizes...")
self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
self._sizes = np.frombuffer(
self._bin_buffer,
dtype=np.int32,
count=self._len,
offset=offset)
print_rank_0(" reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
offset=offset + self._sizes.nbytes)
print_rank_0(" reading document index...")
self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
......@@ -464,7 +472,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def __len__(self):
return len(self._index)
#@lru_cache(maxsize=8)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
......
......@@ -81,6 +81,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
sampler level. This allows wrapping of arbitrary data samplers
(sequential, random, WeightedRandomSampler, etc.) with this batch
sampler."""
def __init__(self, sampler, batch_size, drop_last, rank=-1,
world_size=2, wrap_last=False):
super(DistributedBatchSampler, self).__init__(sampler, batch_size,
......@@ -120,7 +121,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
def data_iterator(self, _iter, wrap_around=False):
"""iterates through data and handles wrap around"""
for i, idx in enumerate(_iter):
if i < self.wrap_around%self.batch_size:
if i < self.wrap_around % self.batch_size:
continue
if wrap_around:
self.wrap_around += 1
......@@ -129,6 +130,6 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
def _batch(self, batch):
"""extracts samples only pertaining to this worker's batch"""
start = self.rank*self.batch_size//self.world_size
end = (self.rank+1)*self.batch_size//self.world_size
start = self.rank * self.batch_size // self.world_size
end = (self.rank + 1) * self.batch_size // self.world_size
return batch[start:end]
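The _batch slice above partitions one global batch contiguously across data-parallel ranks; a worked example with an 8-sample batch and 4 ranks:

```python
# Worked example (illustrative sizes): how _batch carves a global batch of
# batch_size=8 across world_size=4 ranks using the start/end arithmetic above.
batch = list(range(8))
batch_size, world_size = 8, 4
for rank in range(world_size):
    start = rank * batch_size // world_size
    end = (rank + 1) * batch_size // world_size
    print(rank, batch[start:end])
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]
```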
......@@ -2,6 +2,8 @@
# put some code used during development and manual testing of
# indexed_dataset.
from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer
import argparse
import os
import sys
......@@ -11,8 +13,6 @@ import torch
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))
from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset
def test_indexed_dataset(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
......@@ -23,12 +23,12 @@ def test_indexed_dataset(args):
if ds.supports_prefetch:
# just prefetch the whole thing in test (so assume it is small)
ds.prefetch(range(len(ds)))
if args.count > len(ds.doc_idx)-1:
args.count = len(ds.doc_idx)-1
if args.count > len(ds.doc_idx) - 1:
args.count = len(ds.doc_idx) - 1
for i in range(args.count):
start = ds.doc_idx[i]
end = ds.doc_idx[i+1]
end = ds.doc_idx[i + 1]
ids = ds[start:end]
print(f"Document {i}:")
print("--------------")
......@@ -39,6 +39,7 @@ def test_indexed_dataset(args):
print(text)
print("---")
def test_indexed_dataset_get(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = build_tokenizer(args)
......@@ -46,19 +47,19 @@ def test_indexed_dataset_get(args):
print(f"size: {size}")
full = ds.get(0)
print(full)
#print(tokenizer.detokenize(full.data.tolist()))
# print(tokenizer.detokenize(full.data.tolist()))
print("---")
end = ds.get(0, offset=size-10)
end = ds.get(0, offset=size - 10)
print(end)
#print(tokenizer.detokenize(end.data.tolist()))
# print(tokenizer.detokenize(end.data.tolist()))
start = ds.get(0, length=10)
print(start)
#print(tokenizer.detokenize(start.data.tolist()))
# print(tokenizer.detokenize(start.data.tolist()))
part = ds.get(0, offset=2, length=8)
print(part)
#print(tokenizer.detokenize(part.data.tolist()))
# print(tokenizer.detokenize(part.data.tolist()))
# def test_albert_dataset(args):
# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
......@@ -77,6 +78,7 @@ def test_indexed_dataset_get(args):
# if i >= args.count-1:
# exit()
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='prefix to data files')
......@@ -118,5 +120,6 @@ def main():
# test_albert_dataset(args)
test_indexed_dataset_get(args)
if __name__ == "__main__":
main()
......@@ -28,21 +28,24 @@ TRAIN_DATA = 0
VAL_DATA = 1
TEST_DATA = 2
def should_split(split):
"""
given split proportions checks if should split
Examples:
>>> should_split([10,0,0])
>>> should_split([10,0,0])
False
>>> should_split([1,.1,.2])
True
"""
return max(split)/sum(split) != 1.
return max(split) / sum(split) != 1.
def get_ext(path):
"""gets path extension"""
return os.path.splitext(path)[1]
def get_dataset(path, **kwargs):
"""gets dataset object based on keyword args and file at `path`"""
if supported_corpus(path):
......@@ -53,17 +56,19 @@ def get_dataset(path, **kwargs):
elif ext in ['.csv', '.tsv']:
text = csv_dataset(path, **kwargs)
else:
raise NotImplementedError('data file type %s is not supported'%(ext))
raise NotImplementedError('data file type %s is not supported' % (ext))
return text
def supported_corpus(corpus_name):
"""checks if corpus name is defined in `corpora.py`"""
return corpus_name in corpora.NAMED_CORPORA
def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.],
delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None,
tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None,
model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None,
delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None,
tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None,
model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None,
parallel_group=None, **kwargs):
"""function to create datasets+tokenizers for common options"""
if isinstance(process_fn, str):
......@@ -71,6 +76,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
if non_binary_cols is not None:
# multilabel dataset support (only for csvs)
label_key = non_binary_cols
def get_dataset_from_path(path_):
if lazy:
# get lazily loaded dataset
......@@ -82,7 +88,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
if torch.distributed.get_rank() == 0 and not exists_lazy(path_, data_type='data'):
# create cached version of dataset for lazy loading if it doesn't exist
text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose)
delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose)
make_lazy(path_, text.X, data_type='data')
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
......@@ -96,7 +102,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
else:
# get dataset
text = get_dataset(path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, preprocess_fn=process_fn)
delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, preprocess_fn=process_fn)
return text
# get one or multiple datasets and concatenate
if isinstance(path, str):
......@@ -108,8 +114,8 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
ds = ConcatDataset(datasets)
# make tokenizer for dataset
if tokenizer is None:
tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type,
pad_token, character_converage, **kwargs)
tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type,
pad_token, character_converage, **kwargs)
ds_type = ''
if 'ds_type' in kwargs:
......@@ -121,7 +127,8 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
dstype = bert_sentencepair_dataset
ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
if d is not None else None for d in ds]
elif ds_type.lower() == 'gpt2':
ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
else:
......
......@@ -21,6 +21,7 @@ import torch
from megatron import data_utils
from megatron import mpu
class DataConfig:
def __init__(self, defaults={}):
......@@ -48,7 +49,8 @@ def make_data_loader(dataset, batch_size, args):
shuffle = args.shuffle
if shuffle:
sampler = data_utils.samplers.RandomSampler(dataset, replacement=True, num_samples=batch_size*args.train_iters)
sampler = data_utils.samplers.RandomSampler(
dataset, replacement=True, num_samples=batch_size * args.train_iters)
else:
sampler = torch.utils.data.SequentialSampler(dataset)
world_size = torch.distributed.get_world_size(
......@@ -204,6 +206,7 @@ def make_loaders(args):
return (train, valid, test), tokenizer
def get_split(args):
"""
Get dataset splits from comma separated string list
......@@ -217,7 +220,7 @@ def get_split(args):
splits = [float(args.split)]
split_total = sum(splits)
if split_total < 1.:
splits.append(1-split_total)
splits.append(1 - split_total)
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
......@@ -226,10 +229,10 @@ def get_split(args):
if args.test_data is not None:
splits[2] = 0.
final_sum = sum(splits)
return [s/final_sum for s in splits]
return [s / final_sum for s in splits]
def configure_data():
def configure_data():
"""add cmdline flags for configuring datasets"""
# These are options that are used by data_utils, but are either
# deprecated or not meant to be exposed to the command line user.
......
......@@ -16,43 +16,46 @@
from .datasets import json_dataset, csv_dataset
import os
class wikipedia(json_dataset):
"""
dataset for wikipedia with arguments configured for convenience
command line usage: `--train-data wikipedia`
"""
PATH = 'data/wikipedia/wikidump_lines.json'
assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"
def __init__(self, **kwargs):
assert os.path.exists(wikipedia.PATH), \
wikipedia.assert_str
if not kwargs:
kwargs = {}
kwargs['text_key'] = 'text'
kwargs['loose_json'] = True
super(wikipedia, self).__init__(wikipedia.PATH, **kwargs)
"""
dataset for wikipedia with arguments configured for convenience
command line usage: `--train-data wikipedia`
"""
PATH = 'data/wikipedia/wikidump_lines.json'
assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"
def __init__(self, **kwargs):
assert os.path.exists(wikipedia.PATH), \
wikipedia.assert_str
if not kwargs:
kwargs = {}
kwargs['text_key'] = 'text'
kwargs['loose_json'] = True
super(wikipedia, self).__init__(wikipedia.PATH, **kwargs)
class webtext(json_dataset):
"""
dataset for webtext with arguments configured for convenience
command line usage: `--train-data webtext`
"""
PATH = 'data/webtext/data.json'
assert_str = "make sure to set PATH for webtext data_utils/corpora.py"
def __init__(self, **kwargs):
assert os.path.exists(webtext.PATH), \
webtext.assert_str
if not kwargs:
kwargs = {}
kwargs['text_key'] = 'text'
kwargs['loose_json'] = True
super(webtext, self).__init__(webtext.PATH, **kwargs)
"""
dataset for webtext with arguments configured for convenience
command line usage: `--train-data webtext`
"""
PATH = 'data/webtext/data.json'
assert_str = "make sure to set PATH for webtext data_utils/corpora.py"
def __init__(self, **kwargs):
assert os.path.exists(webtext.PATH), \
webtext.assert_str
if not kwargs:
kwargs = {}
kwargs['text_key'] = 'text'
kwargs['loose_json'] = True
super(webtext, self).__init__(webtext.PATH, **kwargs)
NAMED_CORPORA = {
'wikipedia': wikipedia,
'webtext': webtext,
'wikipedia': wikipedia,
'webtext': webtext,
}
......@@ -34,6 +34,7 @@ from nltk import tokenize
from .lazy_loader import lazy_array_loader, exists_lazy, make_lazy
from .tokenization import Tokenization
class ConcatDataset(data.Dataset):
"""
Dataset to concatenate multiple datasets.
......@@ -57,7 +58,8 @@ class ConcatDataset(data.Dataset):
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, 'datasets should not be an empty iterable'
self.datasets = list(datasets)
self.is_lazy = sum([isinstance(ds, lazy_array_loader) for ds in self.datasets]) == len(self.datasets)
self.is_lazy = sum([isinstance(ds, lazy_array_loader)
for ds in self.datasets]) == len(self.datasets)
self.cumulative_sizes = self.cumsum(self.datasets)
self._X = None
self._Y = None
......@@ -90,7 +92,8 @@ class ConcatDataset(data.Dataset):
self._lens.extend(data.lens)
else:
for data in self.datasets:
self._lens.extend([len(d['text']) if isinstance(d, dict) else len(d) for d in data])
self._lens.extend([len(d['text']) if isinstance(
d, dict) else len(d) for d in data])
return self._lens
@property
......@@ -116,6 +119,7 @@ class ConcatDataset(data.Dataset):
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return self.cumulative_sizes
class SplitDataset(data.Dataset):
"""
Dataset wrapper to access a subset of another dataset.
......@@ -126,6 +130,7 @@ class SplitDataset(data.Dataset):
ds (Dataset or array-like): List of datasets to be subindexed
split_inds (1D array-like): List of indices part of subset
"""
def __init__(self, ds, split_inds, **kwargs):
self.split_inds = list(split_inds)
self.wrapped_data = ds
......@@ -163,7 +168,8 @@ class SplitDataset(data.Dataset):
for idx in self.split_inds:
yield self.wrapped_data[idx]
def split_ds(ds, split=[.8,.2,.0], shuffle=True):
def split_ds(ds, split=[.8, .2, .0], shuffle=True):
"""
Split a dataset into subsets given proportions of how
much to allocate per split. If a split is 0% returns None for that split.
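The body of split_ds (shown in the next hunk below) allocates whole samples per split and carries the fractional remainder forward in residual_idx, so every non-zero split receives at least one index. The index arithmetic in isolation, with made-up numbers:

```python
# Index arithmetic only (no SplitDataset wrapper), with an illustrative
# 10-example dataset and the default split=[.8, .2, .0]; mirrors the loop in
# the hunk that follows.
ds_len = 10
split = [.8, .2, .0]
start_idx, residual_idx = 0, 0
for i, f in enumerate(split):
    if f != 0:
        proportion = ds_len * split[i]
        residual_idx += proportion % 1
        split_ = int(int(proportion) + residual_idx)
        print(i, list(range(start_idx, start_idx + max(split_, 1))))
        start_idx += split_
        residual_idx %= 1
# 0 [0, 1, 2, 3, 4, 5, 6, 7]
# 1 [8, 9]
# (the 0% test split is skipped and comes back as None)
```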
......@@ -184,18 +190,19 @@ def split_ds(ds, split=[.8,.2,.0], shuffle=True):
np.random.shuffle(inds)
start_idx = 0
residual_idx = 0
rtn_ds = [None]*len(split)
rtn_ds = [None] * len(split)
for i, f in enumerate(split):
if f != 0:
proportion = ds_len*split[i]
proportion = ds_len * split[i]
residual_idx += proportion % 1
split_ = int(int(proportion) + residual_idx)
split_inds = inds[start_idx:start_idx+max(split_, 1)]
split_inds = inds[start_idx:start_idx + max(split_, 1)]
rtn_ds[i] = SplitDataset(ds, split_inds)
start_idx += split_
residual_idx %= 1
return rtn_ds
class csv_dataset(data.Dataset):
"""
Class for loading datasets from csv files.
......@@ -214,9 +221,10 @@ class csv_dataset(data.Dataset):
X (list): all strings from the csv file
Y (np.ndarray): labels to train with
"""
def __init__(self, path, tokenizer=None, preprocess_fn=None, delim=',',
binarize_sent=False, drop_unlabeled=False, text_key='sentence', label_key='label',
**kwargs):
binarize_sent=False, drop_unlabeled=False, text_key='sentence', label_key='label',
**kwargs):
self.is_lazy = False
self.preprocess_fn = preprocess_fn
self.SetTokenizer(tokenizer)
......@@ -229,7 +237,6 @@ class csv_dataset(data.Dataset):
if '.tsv' in self.path:
self.delim = '\t'
self.X = []
self.Y = []
try:
......@@ -239,7 +246,7 @@ class csv_dataset(data.Dataset):
else:
cols += [label_key]
data = pd.read_csv(self.path, sep=self.delim, usecols=cols, encoding='latin-1')
except:
except BaseException:
data = pd.read_csv(self.path, sep=self.delim, usecols=[text_key], encoding='latin-1')
data = data.dropna(axis=0)
......@@ -248,7 +255,7 @@ class csv_dataset(data.Dataset):
try:
self.Y = data[label_key].values
except Exception as e:
self.Y = np.ones(len(self.X))*-1
self.Y = np.ones(len(self.X)) * -1
if binarize_sent:
self.Y = binarize_labels(self.Y, hard=binarize_sent)
......@@ -295,23 +302,25 @@ class csv_dataset(data.Dataset):
write the metrics, text, and labels to a csv file
"""
if path is None:
path = self.path+'.results'
path = self.path + '.results'
print('generating csv at ' + path)
with open(path, 'w') as csvfile:
c = csv.writer(csvfile, delimiter=self.delim)
if writer_gen is not None:
#if first item of generator is a header of what the metrics mean then write header to csv file
# if first item of generator is a header of what the metrics mean then
# write header to csv file
if not skip_header:
header = (self.label_key,)+tuple(next(writer_gen))+(self.text_key,)
header = (self.label_key,) + tuple(next(writer_gen)) + (self.text_key,)
c.writerow(header)
for i, row in enumerate(writer_gen):
row = (self.Y[i],)+tuple(row)+(self.X[i],)
row = (self.Y[i],) + tuple(row) + (self.X[i],)
c.writerow(row)
else:
c.writerow([self.label_key, self.text_key])
for row in zip(self.Y, self.X):
c.writerow(row)
class json_dataset(data.Dataset):
"""
Class for loading datasets from a json dump.
......@@ -327,8 +336,9 @@ class json_dataset(data.Dataset):
all_strs (list): list of all strings from the dataset
all_labels (list): list of all labels from the dataset (if they have it)
"""
def __init__(self, path, tokenizer=None, preprocess_fn=None, binarize_sent=False,
text_key='sentence', label_key='label', loose_json=False, **kwargs):
text_key='sentence', label_key='label', loose_json=False, **kwargs):
self.is_lazy = False
self.preprocess_fn = preprocess_fn
self.path = path
......@@ -389,24 +399,25 @@ class json_dataset(data.Dataset):
write the metrics, text, and labels to a json file
"""
if path is None:
path = self.path+'.results'
path = self.path + '.results'
jsons = []
if writer_gen is not None:
#if first item of generator is a header of what the metrics mean then write header to csv file
# if first item of generator is a header of what the metrics mean then
# write header to csv file
def gen_helper():
keys = {}
keys[0] = self.label_key
if not skip_header:
for idx, k in enumerate(tuple(next(writer_gen))):
keys[idx+1] = k
keys[idx + 1] = k
for i, row in enumerate(writer_gen):
if i == 0 and skip_header:
for idx, _ in enumerate(row):
keys[idx+1] = 'metric_%d'%(idx,)
keys[idx + 1] = 'metric_%d' % (idx,)
j = {}
for idx, v in enumerate((self.Y[i],)+tuple(row)):
for idx, v in enumerate((self.Y[i],) + tuple(row)):
k = keys[idx]
j[k] = v
yield j
......@@ -453,6 +464,7 @@ class json_dataset(data.Dataset):
j[self.label_key] = -1
yield j
class GPT2Dataset(data.Dataset):
def __init__(self, ds,
......@@ -503,7 +515,7 @@ class GPT2Dataset(data.Dataset):
def __getitem__(self, idx):
# init rng
rng = random.Random(idx)
rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
# get possibly weighted random index from dataset
data_idx = self.get_weighted_samples(rng)
......@@ -538,10 +550,10 @@ class GPT2Dataset(data.Dataset):
else:
data_idx = (data_idx + 1) % self.ds_len
tokens += self.getidx(data_idx)
tokens = tokens[:(self.max_seq_len+1)]
tokens = tokens[:(self.max_seq_len + 1)]
tokens = self.pad_seq(tokens)
return {'text': np.array(tokens),}
return {'text': np.array(tokens), }
def getidx(self, data_idx):
data = self.ds[data_idx]
......@@ -556,7 +568,7 @@ class GPT2Dataset(data.Dataset):
def pad_seq(self, seq):
total_tokens = self.max_seq_len + 1
num_pad_tokens = max(0, total_tokens - len(seq))
seq += [self.tokenizer.get_command('pad').Id]*(num_pad_tokens)
seq += [self.tokenizer.get_command('pad').Id] * (num_pad_tokens)
return seq
def contains_sentence_end(self, tok):
......@@ -569,6 +581,7 @@ class GPT2Dataset(data.Dataset):
return True
return False
class bert_sentencepair_dataset(data.Dataset):
"""
Dataset containing sentencepairs for BERT training. Each index corresponds to a randomly generated sentence pair.
......@@ -581,7 +594,9 @@ class bert_sentencepair_dataset(data.Dataset):
dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
"""
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True, **kwargs):
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None,
short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True, **kwargs):
self.ds = ds
self.ds_len = len(self.ds)
self.tokenizer = self.ds.GetTokenizer()
......@@ -590,12 +605,12 @@ class bert_sentencepair_dataset(data.Dataset):
self.max_seq_len = max_seq_len
self.mask_lm_prob = mask_lm_prob
if max_preds_per_seq is None:
max_preds_per_seq = math.ceil(max_seq_len*mask_lm_prob /10)*10
max_preds_per_seq = math.ceil(max_seq_len * mask_lm_prob / 10) * 10
self.max_preds_per_seq = max_preds_per_seq
self.short_seq_prob = short_seq_prob
self.dataset_size = dataset_size
if self.dataset_size is None:
self.dataset_size = self.ds_len * (self.ds_len-1)
self.dataset_size = self.ds_len * (self.ds_len - 1)
self.presplit_sentences = presplit_sentences
if not self.presplit_sentences:
nltk.download('punkt', download_dir="./nltk")
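One small piece of arithmetic worth spelling out from the hunk above: with the default max_seq_len=512 and mask_lm_prob=0.15, max_preds_per_seq is rounded up to the next multiple of ten.

```python
# Arithmetic check with the defaults shown above (512 tokens, 15% masking):
# the expression rounds max_seq_len * mask_lm_prob up to a multiple of 10.
import math

max_seq_len, mask_lm_prob = 512, 0.15
print(math.ceil(max_seq_len * mask_lm_prob / 10) * 10)  # 80 (512 * 0.15 = 76.8)
```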
......@@ -607,7 +622,8 @@ class bert_sentencepair_dataset(data.Dataset):
if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
lens = np.array(self.ds.lens)
else:
lens = np.array([len(d['text']) if isinstance(d, dict) else len(d) for d in self.ds])
lens = np.array([len(d['text']) if isinstance(d, dict) else len(d)
for d in self.ds])
self.total_len = np.sum(lens)
self.weighting = list(accumulate(lens))
else:
......@@ -626,7 +642,7 @@ class bert_sentencepair_dataset(data.Dataset):
def __getitem__(self, idx):
# get rng state corresponding to index (allows deterministic random pair)
rng = random.Random(idx)
np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
# get seq length
target_seq_length = self.max_seq_len
short_seq = False
......@@ -639,15 +655,25 @@ class bert_sentencepair_dataset(data.Dataset):
lena = 0
lenb = 0
while (is_random_next is None) or (lena < 1) or (lenb < 1):
tokensa, tokensb, is_random_next = self.create_random_sentencepair(target_seq_length, rng, np_rng)
tokensa, tokensb, is_random_next = self.create_random_sentencepair(
target_seq_length, rng, np_rng)
lena = len(tokensa[0])
lenb = len(tokensb[0])
# truncate sentence pair to max_seq_len
tokensa, tokensb = self.truncate_seq_pair(tokensa, tokensb, self.max_seq_len, rng)
# join sentence pair, mask, and pad
tokens, mask, mask_labels, pad_mask = self.create_masked_lm_predictions(tokensa, tokensb, self.mask_lm_prob, self.max_preds_per_seq, self.vocab_words, rng)
sample = {'text': np.array(tokens[0]), 'types': np.array(tokens[1]), 'is_random': int(is_random_next), 'mask': np.array(mask), 'mask_labels': np.array(mask_labels), 'pad_mask': np.array(pad_mask)}
tokens, mask, mask_labels, pad_mask = self.create_masked_lm_predictions(
tokensa, tokensb, self.mask_lm_prob, self.max_preds_per_seq, self.vocab_words, rng)
sample = {
'text': np.array(
tokens[0]),
'types': np.array(
tokens[1]),
'is_random': int(is_random_next),
'mask': np.array(mask),
'mask_labels': np.array(mask_labels),
'pad_mask': np.array(pad_mask)}
return sample
def sentence_split(self, document):
......@@ -665,7 +691,7 @@ class bert_sentencepair_dataset(data.Dataset):
"""tokenize sentence and get token types"""
tokens = self.tokenizer.EncodeAsIds(sent).tokenization
str_type = 'str' + str(sentence_num)
token_types = [self.tokenizer.get_type(str_type).Id]*len(tokens)
token_types = [self.tokenizer.get_type(str_type).Id] * len(tokens)
return tokens, token_types
def get_doc(self, idx):
......@@ -694,21 +720,22 @@ class bert_sentencepair_dataset(data.Dataset):
# doc_a_idx = np_rng.choice(self.ds_len, p=self.weighting)
doc_a_idx = self.get_weighted_samples(np_rng)
else:
doc_a_idx = rng.randint(0, self.ds_len-1)
doc_a_idx = rng.randint(0, self.ds_len - 1)
doc_a = self.sentence_split(self.get_doc(doc_a_idx))
if not doc_a:
doc_a = None
random_start_a = rng.randint(0, len(doc_a)-1)
random_start_a = rng.randint(0, len(doc_a) - 1)
while random_start_a < len(doc_a):
sentence = doc_a[random_start_a]
sentence, sentence_types = self.sentence_tokenize(sentence, 0, random_start_a == 0, random_start_a == len(doc_a))
sentence, sentence_types = self.sentence_tokenize(
sentence, 0, random_start_a == 0, random_start_a == len(doc_a))
curr_strs.append(sentence)
curr_str_types.append(sentence_types)
curr_len += len(sentence)
if random_start_a == len(doc_a) - 1 or curr_len >= target_seq_length:
break
random_start_a = (random_start_a+1)
random_start_a = (random_start_a + 1)
if curr_strs:
num_a = 1
......@@ -738,16 +765,17 @@ class bert_sentencepair_dataset(data.Dataset):
if not doc_b:
doc_b = None
random_start_b = rng.randint(0, len(doc_b)-1)
random_start_b = rng.randint(0, len(doc_b) - 1)
while random_start_b < len(doc_b):
sentence_b = doc_b[random_start_b]
new_b_tokens, new_b_types = self.sentence_tokenize(sentence_b, 1, random_start_b == 0, random_start_b == len(doc_b))
new_b_tokens, new_b_types = self.sentence_tokenize(
sentence_b, 1, random_start_b == 0, random_start_b == len(doc_b))
b_len += len(new_b_tokens)
tokens_b.extend(new_b_tokens)
token_types_b.extend(new_b_types)
if len(tokens_b) >= target_b_length:
break
random_start_b = (random_start_b+1)
random_start_b = (random_start_b + 1)
else:
is_random_next = False
for j in range(num_a, len(curr_strs)):
......@@ -812,13 +840,15 @@ class bert_sentencepair_dataset(data.Dataset):
def pad_seq(self, seq):
"""helper function to pad sequence pair"""
num_pad = max(0, self.max_seq_len - len(seq))
pad_mask = [0] * len(seq) + [1] * num_pad
pad_mask = [0] * len(seq) + [1] * num_pad
seq += [self.tokenizer.get_command('pad').Id] * num_pad
return seq, pad_mask
def concat_tokens(self, tokens_a, token_types_a, tokens_b, token_types_b):
tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command(
'sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
token_types = [token_types_a[0]] + token_types_a + \
[token_types_a[0]] + token_types_b + [token_types_b[0]]
return tokens, token_types
def create_masked_lm_predictions(self, a, b, mask_lm_prob, max_preds_per_seq, vocab_words, rng):
......@@ -833,7 +863,7 @@ class bert_sentencepair_dataset(data.Dataset):
len_a = len(tokens_a)
len_b = len(tokens_b)
cand_indices = [idx+1 for idx in range(len_a)] + [idx+2+len_a for idx in range(len_b)]
cand_indices = [idx + 1 for idx in range(len_a)] + [idx + 2 + len_a for idx in range(len_b)]
rng.shuffle(cand_indices)
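To see what the reformatted cand_indices expression above computes, here is an illustration with toy segment lengths: candidate mask positions skip index 0 (the ENC/CLS token) and the separator between the two segments.

```python
# Illustration with toy segment lengths: candidate indices for masking skip
# the ENC token at position 0 and the sep token between segments a and b.
len_a, len_b = 3, 2
cand_indices = ([idx + 1 for idx in range(len_a)]
                + [idx + 2 + len_a for idx in range(len_b)])
print(cand_indices)  # [1, 2, 3, 5, 6] -- index 0 is ENC, index 4 is the first sep
```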
......
......@@ -169,7 +169,7 @@ def http_get(url, temp_file):
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
......
......@@ -22,11 +22,13 @@ from itertools import accumulate
import torch
from torch.multiprocessing import Lock
def get_lazy_path(path):
"""
Gets directory path where lazy files are stored.
"""
return os.path.splitext(path)[0]+'.lazy'
return os.path.splitext(path)[0] + '.lazy'
def exists_lazy(path, data_type='data'):
"""
......@@ -37,10 +39,11 @@ def exists_lazy(path, data_type='data'):
contents = os.listdir(get_lazy_path(path))
if data_type not in contents:
return False
if data_type+'.len.pkl' not in contents:
if data_type + '.len.pkl' not in contents:
return False
return True
def make_lazy(path, strs, data_type='data'):
"""
Make lazy version of `data_type` field of the file. Byte offsets
......@@ -50,7 +53,7 @@ def make_lazy(path, strs, data_type='data'):
if not os.path.exists(lazypath):
os.makedirs(lazypath)
datapath = os.path.join(lazypath, data_type)
lenpath = os.path.join(lazypath, data_type+'.len.pkl')
lenpath = os.path.join(lazypath, data_type + '.len.pkl')
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
with open(datapath, 'wb') as f:
str_lens = []
......@@ -67,28 +70,32 @@ def make_lazy(path, strs, data_type='data'):
while not os.path.exists(lenpath):
time.sleep(1)
def split_strings(strings, start, chr_lens):
"""
Split strings based on string lengths and given start.
"""
return [strings[i-start:j-start] for i, j in zip([start]+chr_lens[:-1], chr_lens)]
return [strings[i - start:j - start] for i, j in zip([start] + chr_lens[:-1], chr_lens)]
class ProcessorTokenizer:
"""
callable class that runs a preprocessing, as well as tokenization step,
on input text.
"""
def __init__(self, tokenizer, process_fn=None):
self.tokenizer = tokenizer
self.process_fn = process_fn
def __call__(self, string):
if self.tokenizer is not None:
string = self.tokenizer(string, process_fn=self.process_fn)
string = self.tokenizer(string, process_fn=self.process_fn)
elif self.process_fn is not None:
string = self.process_fn(string)
string = self.process_fn(string)
return string
class lazy_array_loader(object):
"""
Arguments:
......@@ -107,17 +114,18 @@ class lazy_array_loader(object):
data_type2
data_type2.len.pkl
"""
def __init__(self, path, data_type='data', mem_map=False, map_fn=None):
lazypath = get_lazy_path(path)
datapath = os.path.join(lazypath, data_type)
#get file where array entries are concatenated into one big string
# get file where array entries are concatenated into one big string
self._file = open(datapath, 'rb', buffering=0)
self.file = self._file
#memory map file if necessary
# memory map file if necessary
self.mem_map = mem_map
if self.mem_map:
self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ)
lenpath = os.path.join(lazypath, data_type+'.len.pkl')
lenpath = os.path.join(lazypath, data_type + '.len.pkl')
self.lens = pkl.load(open(lenpath, 'rb'))
self.ends = list(accumulate(self.lens))
self.dumb_ends = list(self.ends)
......@@ -149,7 +157,7 @@ class lazy_array_loader(object):
if index == 0:
start = 0
else:
start = self.ends[index-1]
start = self.ends[index - 1]
end = self.ends[index]
rtn = self.file_read(start, end)
if self.map_fn is not None:
......@@ -160,7 +168,7 @@ class lazy_array_loader(object):
if index.start == 0 or index.start is None:
start = 0
else:
start = self.ends[index.start-1]
start = self.ends[index.start - 1]
stop = chr_lens[-1]
strings = self.file_read(start, stop)
rtn = split_strings(strings, start, chr_lens)
......@@ -181,15 +189,14 @@ class lazy_array_loader(object):
# read to end of file if no end point provided
if end is None:
rtn = self.file.read()
#else read amount needed to reach end point
# else read amount needed to reach end point
else:
rtn = self.file.read(end-start)
rtn = self.file.read(end - start)
self.read_lock.release()
#TODO: @raulp figure out mem map byte string bug
#if mem map'd need to decode byte string to string
# TODO: @raulp figure out mem map byte string bug
# if mem map'd need to decode byte string to string
rtn = rtn.decode('utf-8', 'ignore')
# rtn = str(rtn)
if self.mem_map:
rtn = rtn.decode('unicode_escape')
return rtn
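To make the byte-offset bookkeeping in lazy_array_loader above concrete: lens holds per-entry character counts, ends their running totals, and split_strings recovers individual entries from a slice of the concatenated data. A toy walk-through with made-up strings:

```python
# Toy walk-through (made-up strings) of the offset bookkeeping used above.
from itertools import accumulate

def split_strings(strings, start, chr_lens):
    return [strings[i - start:j - start]
            for i, j in zip([start] + chr_lens[:-1], chr_lens)]

entries = ['hello', 'lazy', 'world']
concatenated = ''.join(entries)     # what the .lazy data file stores
lens = [len(e) for e in entries]    # what the .len.pkl file stores
ends = list(accumulate(lens))       # [5, 9, 14]
# Read entries 1..2: start at ends[0], stop at ends[2], then split.
start, stop = ends[0], ends[2]
print(split_strings(concatenated[start:stop], start, ends[1:3]))  # ['lazy', 'world']
```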
......@@ -21,6 +21,7 @@ import torch
from torch.utils import data
import numpy as np
class RandomSampler(data.sampler.Sampler):
r"""
Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
......@@ -63,7 +64,8 @@ class RandomSampler(data.sampler.Sampler):
if self.epoch >= 0:
g.manual_seed(self.epoch)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=g).tolist())
return iter(torch.randint(high=n, size=(self.num_samples,),
dtype=torch.int64, generator=g).tolist())
return iter(torch.randperm(n, generator=g).tolist())
def __len__(self):
......@@ -72,12 +74,14 @@ class RandomSampler(data.sampler.Sampler):
def set_epoch(self, epoch):
self.epoch = epoch
class DistributedBatchSampler(data.sampler.BatchSampler):
"""
similar to normal implementation of distributed sampler, except implementation is at the
batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary
data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
"""
def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False):
super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
if rank == -1:
......@@ -125,7 +129,7 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
def data_iterator(self, _iter, wrap_around=False):
"""iterates through data and handles wrap around"""
for i, idx in enumerate(_iter):
if i < self.wrap_around%self.batch_size:
if i < self.wrap_around % self.batch_size:
continue
if wrap_around:
self.wrap_around += 1
......@@ -134,6 +138,6 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
def _batch(self, batch):
"""extracts samples only pertaining to this worker's batch"""
start = self.rank*self.batch_size//self.world_size
end = (self.rank+1)*self.batch_size//self.world_size
start = self.rank * self.batch_size // self.world_size
end = (self.rank + 1) * self.batch_size // self.world_size
return batch[start:end]
......@@ -16,12 +16,12 @@ output_file = sys.argv[2]
line_seperator = "\n"
with open(input_file, 'r') as ifile:
with open(output_file, "w") as ofile:
for doc in ifile.readlines():
parsed = json.loads(doc)
sent_list = []
for line in parsed['text'].split('\n'):
if line != '\n':
sent_list.extend(nltk.tokenize.sent_tokenize(line))
parsed['text'] = line_seperator.join(sent_list)
ofile.write(json.dumps(parsed)+'\n')
with open(output_file, "w") as ofile:
for doc in ifile.readlines():
parsed = json.loads(doc)
sent_list = []
for line in parsed['text'].split('\n'):
if line != '\n':
sent_list.extend(nltk.tokenize.sent_tokenize(line))
parsed['text'] = line_seperator.join(sent_list)
ofile.write(json.dumps(parsed) + '\n')
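The re-indented loop above splits each JSON document's text field into one sentence per line. A small demonstration with example text; downloading 'punkt' fetches the tokenizer models sent_tokenize needs (newer NLTK versions may also require 'punkt_tab').

```python
# Small demonstration (example text only) of the sentence-splitting loop above.
import json
import nltk

nltk.download('punkt', quiet=True)  # tokenizer models used by sent_tokenize

line_seperator = "\n"
parsed = {'text': "Megatron trains large transformers. It uses model parallelism.\nSecond paragraph here."}
sent_list = []
for line in parsed['text'].split('\n'):
    if line != '\n':
        sent_list.extend(nltk.tokenize.sent_tokenize(line))
parsed['text'] = line_seperator.join(sent_list)
print(json.dumps(parsed))
# {"text": "Megatron trains large transformers.\nIt uses model parallelism.\nSecond paragraph here."}
```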
......@@ -18,7 +18,7 @@ Takes a corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, val.json, test.json files
under `output_dir`.
Note: This code has the potential to override files with the names
Note: This code has the potential to override files with the names
train.json, val.json, test.json in `--output_dir`.
"""
import os
......@@ -35,6 +35,7 @@ parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
help='percentage of available data to use for val/test dataset')
args = parser.parse_args()
def get_lines(filepath):
lines = []
with open(filepath, 'r') as f:
......@@ -43,6 +44,7 @@ def get_lines(filepath):
lines.append(l)
return lines
def get_splits(lines, line_counts):
all_lines = []
line_idx = []
......@@ -50,14 +52,14 @@ def get_splits(lines, line_counts):
for i, l in enumerate(lines):
all_lines.extend(l)
line_idx.extend(list(range(len(l))))
file_mappings.extend([i]*len(l))
file_mappings.extend([i] * len(l))
indices = list(range(len(all_lines)))
random.shuffle(indices)
all_lines = [all_lines[idx] for idx in indices]
line_idx = [line_idx[idx] for idx in indices]
file_mappings = [file_mappings[idx] for idx in indices]
splits = []
mappings = []
start = 0
......@@ -68,10 +70,11 @@ def get_splits(lines, line_counts):
start = end
return splits, mappings
def format_mappings(line_idx, file_mappings):
lines = []
for m, l in zip(file_mappings, line_idx):
lines.append(str(m).strip()+'\t'+str(l).strip())
lines.append(str(m).strip() + '\t' + str(l).strip())
return lines
......@@ -85,25 +88,30 @@ def get_filepaths(filepaths, output_dir):
paths.append(os.path.join(output_dir, test_path))
return paths
def write_files(lines, mappings, filepaths):
for l, m, path in zip(lines, mappings, filepaths):
write_file(l, path)
write_mapping_file(m, path)
def write_file(lines, path):
print('Writing:', path)
with open(path, 'w') as f:
for l in lines:
f.write(l+'\n')
f.write(l + '\n')
def write_mapping_file(m, path):
path = path+'.map'
m = [get_mapping_header()]+m
path = path + '.map'
m = [get_mapping_header()] + m
write_file(m, path)
def get_mapping_header():
return 'file\tline #'
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
......@@ -113,16 +121,16 @@ for filepath in args.input_files:
_lines = get_lines(filepath)
lines.append(_lines)
#calculate number of lines to use for each
# calculate number of lines to use for each
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent*total_lines)
dev_lines = math.ceil(dev_percent * total_lines)
test_percent = 0
if len(args.test_percent)==2:
test_percent=args.test_percent[1]
test_lines = math.ceil(test_percent*total_lines)
train_lines = total_lines-(test_lines+dev_lines)
if len(args.test_percent) == 2:
test_percent = args.test_percent[1]
test_lines = math.ceil(test_percent * total_lines)
train_lines = total_lines - (test_lines + dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
......@@ -131,4 +139,3 @@ splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)
......@@ -3,7 +3,7 @@ Takes a corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, val.json, test.json files
under `output_dir`.
Note: This code has the potential to override files with the names
Note: This code has the potential to override files with the names
train.json, val.json, test.json in `--output_dir`.
"""
import os
......@@ -20,6 +20,7 @@ parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
help='percentage of available data to use for val/test dataset')
args = parser.parse_args()
def get_lines(filepath):
lines = []
with open(filepath, 'r') as f:
......@@ -28,6 +29,7 @@ def get_lines(filepath):
lines.append(l)
return lines
def get_splits(lines, line_counts):
all_lines = []
line_idx = []
......@@ -35,14 +37,14 @@ def get_splits(lines, line_counts):
for i, l in enumerate(lines):
all_lines.extend(l)
line_idx.extend(list(range(len(l))))
file_mappings.extend([i]*len(l))
file_mappings.extend([i] * len(l))
indices = list(range(len(all_lines)))
random.shuffle(indices)
all_lines = [all_lines[idx] for idx in indices]
line_idx = [line_idx[idx] for idx in indices]
file_mappings = [file_mappings[idx] for idx in indices]
splits = []
mappings = []
start = 0
......@@ -53,10 +55,11 @@ def get_splits(lines, line_counts):
start = end
return splits, mappings
def format_mappings(line_idx, file_mappings):
lines = []
for m, l in zip(file_mappings, line_idx):
lines.append(str(m).strip()+'\t'+str(l).strip())
lines.append(str(m).strip() + '\t' + str(l).strip())
return lines
......@@ -70,25 +73,30 @@ def get_filepaths(filepaths, output_dir):
paths.append(os.path.join(output_dir, test_path))
return paths
def write_files(lines, mappings, filepaths):
for l, m, path in zip(lines, mappings, filepaths):
write_file(l, path)
write_mapping_file(m, path)
def write_file(lines, path):
print('Writing:', path)
with open(path, 'w') as f:
for l in lines:
f.write(l+'\n')
f.write(l + '\n')
def write_mapping_file(m, path):
path = path+'.map'
m = [get_mapping_header()]+m
path = path + '.map'
m = [get_mapping_header()] + m
write_file(m, path)
def get_mapping_header():
return 'file\tline #'
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
......@@ -98,16 +106,16 @@ for filepath in args.input_files:
_lines = get_lines(filepath)
lines.append(_lines)
#calculate number of lines to use for each
# calculate number of lines to use for each
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent*total_lines)
dev_lines = math.ceil(dev_percent * total_lines)
test_percent = 0
if len(args.test_percent)==2:
test_percent=args.test_percent[1]
test_lines = math.ceil(test_percent*total_lines)
train_lines = total_lines-(test_lines+dev_lines)
if len(args.test_percent) == 2:
test_percent = args.test_percent[1]
test_lines = math.ceil(test_percent * total_lines)
train_lines = total_lines - (test_lines + dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
......@@ -116,4 +124,3 @@ splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)
......@@ -14,20 +14,22 @@
# limitations under the License.
"""PyTorch DataLoader for TFRecords"""
import numpy as np
import torch
import queue
import threading
import tensorflow as tf
tf.enable_eager_execution()
import torch
import numpy as np
class TFRecordDataLoader(object):
def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1, threaded_dl=False):
def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq,
train, num_workers=2, seed=1, threaded_dl=False):
assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords"
tf.set_random_seed(seed)
if isinstance(records, str):
records = [records]
records = [records]
self.record_converter = Record2Example({"input_ids": tf.FixedLenFeature([max_seq_len], tf.int64),
"input_mask": tf.FixedLenFeature([max_seq_len], tf.int64),
......@@ -37,7 +39,7 @@ class TFRecordDataLoader(object):
"masked_lm_weights": tf.FixedLenFeature([max_preds_per_seq], tf.float32),
"next_sentence_labels": tf.FixedLenFeature([1], tf.int64)})
#Instantiate dataset according to original BERT implementation
# Instantiate dataset according to original BERT implementation
if train:
self.dataset = tf.data.Dataset.from_tensor_slices(tf.constant(records))
self.dataset = self.dataset.repeat()
......@@ -55,10 +57,12 @@ class TFRecordDataLoader(object):
self.dataset = self.dataset.repeat()
# Instantiate dataloader (do not drop remainder for eval)
loader_args = {'batch_size': batch_size,
loader_args = {'batch_size': batch_size,
'num_parallel_batches': num_workers,
'drop_remainder': train}
self.dataloader = self.dataset.apply(tf.contrib.data.map_and_batch(self.record_converter, **loader_args))
self.dataloader = self.dataset.apply(
tf.contrib.data.map_and_batch(
self.record_converter, **loader_args))
self.threaded_dl = threaded_dl
self.num_workers = num_workers
......@@ -72,6 +76,7 @@ class TFRecordDataLoader(object):
for item in data_iter:
yield convert_tf_example_to_torch_tensors(item)
class Record2Example(object):
def __init__(self, feature_map):
self.feature_map = feature_map
......@@ -84,23 +89,25 @@ class Record2Example(object):
example[k] = tf.to_int32(v)
return example
def convert_tf_example_to_torch_tensors(example):
item = {k: (v.numpy()) for k,v in example.items()}
item = {k: (v.numpy()) for k, v in example.items()}
mask = np.zeros_like(item['input_ids'])
mask_labels = np.ones_like(item['input_ids'])*-1
mask_labels = np.ones_like(item['input_ids']) * -1
for b, row in enumerate(item['masked_lm_positions'].astype(int)):
for i, idx in enumerate(row):
if item['masked_lm_weights'][b, i] != 0:
mask[b, idx] = 1
mask_labels[b, idx] = item['masked_lm_ids'][b, i]
output = {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'],
'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
return {k: torch.from_numpy(v) for k,v in output.items()}
output = {'text': item['input_ids'], 'types': item['segment_ids'], 'is_random': item['next_sentence_labels'],
'pad_mask': 1 - item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
return {k: torch.from_numpy(v) for k, v in output.items()}
class MultiprocessLoader(object):
def __init__(self, dataloader, num_workers=2):
self.dl = dataloader
self.queue_size = 2*num_workers
self.queue_size = 2 * num_workers
def __iter__(self):
output_queue = queue.Queue(self.queue_size)
......@@ -114,8 +121,9 @@ class MultiprocessLoader(object):
else:
print(RuntimeError('TF record data loader thread exited unexpectedly'))
def _multiproc_iter(dl, output_queue):
data_iter = iter(dl)
for item in data_iter:
tensors = convert_tf_example_to_torch_tensors(item)
output_queue.put(tensors, block=True)
\ No newline at end of file
output_queue.put(tensors, block=True)