Commit 316d3f90 authored by Pan,Huiwen's avatar Pan,Huiwen
Browse files

增加ds框架测试模型

parent aebde649
import itertools
import os
import pickle
import shutil
import numpy as np
import torch
from megatron import get_args
from megatron import mpu
def detach(tensor):
return tensor.detach().cpu().numpy()
class BlockData(object):
"""Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
def __init__(self, block_data_path=None, load_from_path=True, rank=None):
self.embed_data = dict()
self.meta_data = dict()
if block_data_path is None:
args = get_args()
block_data_path = args.block_data_path
rank = args.rank
self.block_data_path = block_data_path
self.rank = rank
if load_from_path:
self.load_from_file()
block_data_name = os.path.splitext(self.block_data_path)[0]
self.temp_dir_name = block_data_name + '_tmp'
def state(self):
return {
'embed_data': self.embed_data,
'meta_data': self.meta_data,
}
def clear(self):
"""Clear the embedding data structures to save memory.
The metadata ends up getting used, and is also much smaller in dimensionality
so it isn't really worth clearing.
"""
self.embed_data = dict()
def load_from_file(self):
"""Populate members from instance saved to file"""
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Unpickling BlockData", flush=True)
state_dict = pickle.load(open(self.block_data_path, 'rb'))
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Finished unpickling BlockData\n", flush=True)
self.embed_data = state_dict['embed_data']
self.meta_data = state_dict['meta_data']
def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
"""Add data for set of blocks
:param block_indices: 1D array of unique int ids for the blocks
:param block_embeds: 2D array of embeddings of the blocks
:param block_metas: 2D array of metadata for the blocks.
In the case of REALM this will be [start_idx, end_idx, doc_idx]
"""
for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
if not allow_overwrite and idx in self.embed_data:
raise ValueError("Unexpectedly tried to overwrite block data")
self.embed_data[idx] = np.float16(embed)
self.meta_data[idx] = meta
def save_shard(self):
"""Save the block data that was created this in this process"""
if not os.path.isdir(self.temp_dir_name):
os.makedirs(self.temp_dir_name, exist_ok=True)
# save the data for each shard
with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
pickle.dump(self.state(), data_file)
def merge_shards_and_save(self):
"""Combine all the shards made using self.save_shard()"""
shard_names = os.listdir(self.temp_dir_name)
seen_own_shard = False
for fname in os.listdir(self.temp_dir_name):
shard_rank = int(os.path.splitext(fname)[0])
if shard_rank == self.rank:
seen_own_shard = True
continue
with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
data = pickle.load(f)
old_size = len(self.embed_data)
shard_size = len(data['embed_data'])
# add the shard's data and check to make sure there is no overlap
self.embed_data.update(data['embed_data'])
self.meta_data.update(data['meta_data'])
assert len(self.embed_data) == old_size + shard_size
assert seen_own_shard
# save the consolidated shards and remove temporary directory
with open(self.block_data_path, 'wb') as final_file:
pickle.dump(self.state(), final_file)
shutil.rmtree(self.temp_dir_name, ignore_errors=True)
print("Finished merging {} shards for a total of {} embeds".format(
len(shard_names), len(self.embed_data)), flush=True)
class FaissMIPSIndex(object):
"""Wrapper object for a BlockData which similarity search via FAISS under the hood"""
def __init__(self, embed_size, block_data=None, use_gpu=False):
self.embed_size = embed_size
self.block_data = block_data
self.use_gpu = use_gpu
self.id_map = dict()
self.block_mips_index = None
self._set_block_index()
def _set_block_index(self):
"""Create a Faiss Flat index with inner product as the metric to search against"""
try:
import faiss
except ImportError:
raise Exception("Error: Please install faiss to use FaissMIPSIndex")
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Building index", flush=True)
self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
if self.use_gpu:
# create resources and config for GpuIndex
res = faiss.StandardGpuResources()
config = faiss.GpuIndexFlatConfig()
config.device = torch.cuda.current_device()
config.useFloat16 = True
self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
else:
# CPU index supports IDs so wrap with IDMap
self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on CPU", flush=True)
# if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
if self.block_data is not None:
self.add_block_embed_data(self.block_data)
def reset_index(self):
"""Delete existing index and create anew"""
del self.block_mips_index
# reset the block data so that _set_block_index will reload it as well
if self.block_data is not None:
block_data_path = self.block_data.block_data_path
del self.block_data
self.block_data = BlockData(block_data_path)
self._set_block_index()
def add_block_embed_data(self, all_block_data):
"""Add the embedding of each block to the underlying FAISS index"""
# this assumes the embed_data is a dict : {int: np.array<float>}
block_indices, block_embeds = zip(*all_block_data.embed_data.items())
# the embeddings have to be entered in as float32 even though the math internally is done with float16.
block_embeds_arr = np.float32(np.array(block_embeds))
block_indices_arr = np.array(block_indices)
# faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
if self.use_gpu:
for i, idx in enumerate(block_indices):
self.id_map[i] = idx
# we no longer need the embedding data since it's in the index now
all_block_data.clear()
if self.use_gpu:
self.block_mips_index.add(block_embeds_arr)
else:
self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">>> Finished adding block data to index", flush=True)
def search_mips_index(self, query_embeds, top_k, reconstruct=True):
"""Get the top-k blocks by the index distance metric.
:param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
if False: return [num_queries x k] array of distances, and another for indices
"""
query_embeds = np.float32(detach(query_embeds))
if reconstruct:
# get the vectors themselves
top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
return top_k_block_embeds
else:
# get distances and indices of closest vectors
distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
if self.use_gpu:
fresh_indices = np.zeros(block_indices.shape)
for i, j in itertools.product(block_indices.shape):
fresh_indices[i, j] = self.id_map[block_indices[i, j]]
block_indices = fresh_indices
return distances, block_indices
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batch samplers that work with either random or sequential data samplers."""
import torch
from torch.utils import data
class RandomSampler(data.sampler.Sampler):
"""Based off of pytorch RandomSampler and DistributedSampler. Essentially
a RandomSampler, but this class lets the user set an epoch like
DistributedSampler Samples elements randomly. If without replacement, then
sample from a shuffled dataset. If with replacement, then user can
specify ``num_samples`` to draw.
Arguments:
data_source (Dataset): dataset to sample from
num_samples (int): number of samples to draw, default=len(dataset)
replacement (bool): samples are drawn with replacement if ``True``,
default=False
"""
def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.epoch = -1
if self._num_samples is not None and replacement is False:
raise ValueError("With replacement=False, num_samples should not "
"be specified, since a random permute will be "
"performed.")
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(
self.num_samples))
if not isinstance(self.replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement))
@property
def num_samples(self):
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self):
n = len(self.data_source)
g = torch.Generator()
if self.epoch >= 0:
g.manual_seed(self.epoch)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,),
dtype=torch.int64, generator=g).tolist())
return iter(torch.randperm(n, generator=g).tolist())
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
class DistributedBatchSampler(data.sampler.BatchSampler):
"""Similar to normal implementation of distributed sampler, except
implementation is at the batch sampler level, instead of just the
sampler level. This allows wrapping of arbitrary data samplers
(sequential, random, WeightedRandomSampler, etc.) with this batch
sampler.
The `interleave` argument specifies how to distribute a batch. A value
of True combined with the above random sampler is equivalent to pytorch's
torch.utils.data.distributed.DistributedSampler.
For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2
specifying True will result in the following samples for each gpu:
GPU0: [0,2,4,6] GPU1: [1,3,5,7]
specifying False will result in the following samples:
GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""
def __init__(self, sampler, batch_size, drop_last, rank=-1,
world_size=2, wrap_last=False, interleave=False):
super(DistributedBatchSampler, self).__init__(sampler, batch_size,
drop_last)
if rank == -1:
assert False, 'should not be here'
rank = torch.distributed.get_rank()
self.rank = rank
self.world_size = world_size
self.sampler.wrap_around = 0
self.wrap_around = 0
self.wrap_last = wrap_last
self.start_iter = 0
self.interleave = interleave
def __iter__(self):
batch = []
i = 0
for idx in self.data_iterator(self.sampler, wrap_around=False):
batch.append(idx)
if len(batch) == self.batch_size:
tbatch = self._batch(batch)
if i >= self.start_iter:
yield tbatch
self.start_iter = 0
i += 1
batch = []
batch_len = len(batch)
if batch_len > 0 and not self.drop_last:
if self.wrap_last:
self.sampler.wrap_around -= (self.batch_size)
self.wrap_around += (len(batch))
self.wrap_around %= self.batch_size
yield self._batch(batch)
if self.wrap_last:
self.sampler.wrap_around += self.batch_size
def data_iterator(self, _iter, wrap_around=False):
"""iterates through data and handles wrap around"""
for i, idx in enumerate(_iter):
if i < self.wrap_around % self.batch_size:
continue
if wrap_around:
self.wrap_around += 1
self.wrap_around %= self.batch_size
yield idx
def _batch(self, batch):
"""extracts samples only pertaining to this worker's batch"""
if self.interleave:
return batch[self.rank:self.batch_size:self.world_size]
start = self.rank * self.batch_size // self.world_size
end = (self.rank + 1) * self.batch_size // self.world_size
return batch[start:end]
# This file isn't really a formal automated test, it's just a place to
# put some code used during development and manual testing of
# indexed_dataset.
from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer
import argparse
import os
import sys
import torch
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))
def test_indexed_dataset(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = build_tokenizer(args)
print(len(ds.doc_idx))
print(len(ds))
print(ds.doc_idx[-1])
if ds.supports_prefetch:
# just prefetch the whole thing in test (so assume it is small)
ds.prefetch(range(len(ds)))
if args.count > len(ds.doc_idx) - 1:
args.count = len(ds.doc_idx) - 1
for i in range(args.count):
start = ds.doc_idx[i]
end = ds.doc_idx[i + 1]
ids = ds[start:end]
print(f"Document {i}:")
print("--------------")
for s in ids:
assert len(s) > 0
l = s.data.tolist()
text = tokenizer.detokenize(l)
print(text)
print("---")
def test_indexed_dataset_get(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = build_tokenizer(args)
size = ds.sizes[0]
print(f"size: {size}")
full = ds.get(0)
print(full)
# print(tokenizer.detokenize(full.data.tolist()))
print("---")
end = ds.get(0, offset=size - 10)
print(end)
# print(tokenizer.detokenize(end.data.tolist()))
start = ds.get(0, length=10)
print(start)
# print(tokenizer.detokenize(start.data.tolist()))
part = ds.get(0, offset=2, length=8)
print(part)
# print(tokenizer.detokenize(part.data.tolist()))
# def test_albert_dataset(args):
# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
# # ds = AlbertDataset(idataset, tokenizer)
# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
# args.epochs, args.max_num_samples,
# args.masked_lm_prob, args.seq_length,
# args.short_seq_prob, args.seed)
# truncated = 0
# total = 0
# for i, s in enumerate(ds):
# ids = s['text']
# tokens = ds.tokenizer.convert_ids_to_tokens(ids)
# print(tokens)
# if i >= args.count-1:
# exit()
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='prefix to data files')
parser.add_argument('--dataset-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'])
parser.add_argument('--count', type=int, default=10,
help='Number of samples/documents to print')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase',
'GPT2BPETokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
parser.add_argument('--epochs', type=int, default=5,
help='Number of epochs to plan for')
parser.add_argument('--max-num-samples', type=int, default=None,
help='Maximum number of samples to plan for')
parser.add_argument('--masked-lm-prob', type=float, default=0.15,
help='probability of masking tokens')
parser.add_argument('--seq-length', type=int, default=512,
help='maximum sequence length')
parser.add_argument('--short-seq-prob', type=float, default=0.1,
help='probability of creating a short sequence')
parser.add_argument('--seed', type=int, default=1234,
help='random seed')
args = parser.parse_args()
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.model_parallel_size = 1
if args.dataset_impl == "infer":
args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
# test_albert_dataset(args)
test_indexed_dataset_get(args)
if __name__ == "__main__":
main()
#!/bin/bash
IMPL=cached
python ../preprocess_data.py \
--input test_samples.json \
--vocab vocab.txt \
--dataset-impl ${IMPL} \
--output-prefix test_samples_${IMPL} \
--workers 1 \
--log-interval 2
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""utils for creating datasets"""
import os
import math
import torch
from .samplers import DistributedBatchSampler
from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset
from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader
from .tokenization import Tokenization, CommandToken, Tokenizer, CharacterLevelTokenizer, BertWordPieceTokenizer, GPT2BPETokenizer, make_tokenizer
from . import corpora
TRAIN_DATA = 0
VAL_DATA = 1
TEST_DATA = 2
def should_split(split):
"""
given split proportions checks if should split
Examples:
>>> should_split([10,0,0])
False
>>> should_split([1,.1,.2])
True
"""
return max(split) / sum(split) != 1.
def get_ext(path):
"""gets path extension"""
return os.path.splitext(path)[1]
def get_dataset(path, **kwargs):
"""gets dataset object based on keyword args and file at `path`"""
if supported_corpus(path):
return corpora.NAMED_CORPORA[path](**kwargs)
ext = get_ext(path)
if '.json' in ext:
text = json_dataset(path, **kwargs)
elif ext in ['.csv', '.tsv']:
text = csv_dataset(path, **kwargs)
else:
raise NotImplementedError('data file type %s is not supported' % (ext))
return text
def supported_corpus(corpus_name):
"""checks if corpus name is defined in `corpora.py`"""
return corpus_name in corpora.NAMED_CORPORA
def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.],
delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None,
tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None,
model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None,
parallel_group=None, **kwargs):
"""function to create datasets+tokenizers for common options"""
if isinstance(process_fn, str):
process_fn = eval(process_fn)
if non_binary_cols is not None:
# multilabel dataset support (only for csvs)
label_key = non_binary_cols
def get_dataset_from_path(path_):
if lazy:
# get lazily loaded dataset
named_corpora = False
if supported_corpus(path_):
named_corpora = True
name = path_
path_ = corpora.NAMED_CORPORA[path_].PATH
if torch.distributed.get_rank() == 0 and not exists_lazy(path_, data_type='data'):
# create cached version of dataset for lazy loading if it doesn't exist
text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose)
make_lazy(path_, text.X, data_type='data')
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=parallel_group)
assert counts[0].item() == torch.distributed.get_world_size(
group=parallel_group)
text = lazy_array_loader(path_, data_type='data', map_fn=process_fn)
else:
# get dataset
text = get_dataset(path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, preprocess_fn=process_fn)
return text
# get one or multiple datasets and concatenate
if isinstance(path, str):
path = [path]
datasets = [get_dataset_from_path(p) for p in path]
if len(datasets) == 1:
ds = datasets[0]
else:
ds = ConcatDataset(datasets)
# make tokenizer for dataset
if tokenizer is None:
tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type,
pad_token, character_converage, **kwargs)
ds_type = ''
if 'ds_type' in kwargs:
ds_type = kwargs['ds_type']
ds.SetTokenizer(tokenizer)
# Split dataset into train/val/test (and wrap bert dataset)
if should_split(split):
ds = split_ds(ds, split)
if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
dstype = bert_sentencepair_dataset
ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
if d is not None else None for d in ds]
elif ds_type.lower() == 'gpt2':
ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
else:
if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
dstype = bert_sentencepair_dataset
ds = dstype(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
elif ds_type.lower() == 'gpt2':
ds = GPT2Dataset(ds, max_seq_len=seq_length)
return ds, tokenizer
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""parses arguments and preps data loader"""
import copy
import torch
from megatron import data_utils
from megatron import mpu
class DataConfig:
def __init__(self, defaults={}):
super(DataConfig, self).__init__()
self.defaults = defaults
def apply(self, args):
if torch.distributed.get_rank() == 0:
print('configuring data')
self.apply_defaults(args)
return make_loaders(args)
def set_defaults(self, **kwargs):
for k, v in kwargs.items():
self.defaults[k] = v
def apply_defaults(self, args):
for k, v in self.defaults.items():
k = k.replace('-', '_')
if not hasattr(args, k):
setattr(args, k, v)
def make_data_loader(dataset, batch_size, args):
shuffle = args.shuffle
if shuffle:
sampler = data_utils.samplers.RandomSampler(
dataset, replacement=True, num_samples=batch_size * args.train_iters)
else:
sampler = torch.utils.data.SequentialSampler(dataset)
world_size = torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
distributed = world_size > 1
drop_last = distributed
if distributed:
batch_sampler = data_utils.samplers.DistributedBatchSampler(sampler,
batch_size,
drop_last,
rank,
world_size)
else:
batch_sampler = torch.utils.data.BatchSampler(sampler,
batch_size,
drop_last)
data_loader = torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True)
return data_loader
def make_tfrecord_loaders(args):
"""Load train/val/test dataset from shuffled TFRecords"""
import data_utils.tf_dl
data_set_args = {'batch_size': args.batch_size,
'max_seq_len': args.seq_length,
'max_preds_per_seq': args.max_preds_per_seq,
'train': True,
'num_workers': max(args.num_workers, 1),
'seed': args.seed + args.rank + 1,
'threaded_dl': args.num_workers > 0
}
train = data_utils.tf_dl.TFRecordDataLoader(args.train_data,
**data_set_args)
data_set_args['train'] = False
if args.eval_seq_length is not None:
data_set_args['max_seq_len'] = args.eval_seq_length
if args.eval_max_preds_per_seq is not None:
data_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq
valid = None
if args.valid_data is not None:
valid = data_utils.tf_dl.TFRecordDataLoader(args.valid_data,
**data_set_args)
test = None
if args.test_data is not None:
test = data_utils.tf_dl.TFRecordDataLoader(args.test_data,
**data_set_args)
tokenizer = data_utils.make_tokenizer(args.tokenizer_type,
train,
args.tokenizer_path,
args.vocab_size,
args.tokenizer_model_type,
cache_dir=args.cache_dir)
return (train, valid, test), tokenizer
def make_loaders(args):
"""makes training/val/test"""
if args.data_loader == 'tfrecords':
return make_tfrecord_loaders(args)
world_size = torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
batch_size = args.batch_size * world_size
eval_batch_size = batch_size
if args.eval_batch_size is not None:
eval_batch_size = args.eval_batch_size * world_size
seq_length = args.seq_length
if seq_length < 0:
seq_length = seq_length * world_size
eval_seq_length = args.eval_seq_length
if eval_seq_length is not None and eval_seq_length < 0:
eval_seq_length = eval_seq_length * world_size
split = get_split(args)
if args.data_path is not None:
args.train_data = args.data_path
data_set_args = {
'path': args.train_data,
'seq_length': seq_length,
'lazy': args.data_loader == 'lazy',
'delim': args.delim,
'text_key': args.text_key,
'label_key': 'label',
'non_binary_cols': None,
'ds_type': args.data_set_type,
'split': split,
'loose': args.loose_json,
'tokenizer_type': args.tokenizer_type,
'tokenizer_model_path': args.tokenizer_path,
'vocab_size': args.vocab_size,
'model_type': args.tokenizer_model_type,
'cache_dir': args.cache_dir,
'max_preds_per_seq': args.max_preds_per_seq,
'presplit_sentences': args.presplit_sentences,
'parallel_group': mpu.get_data_parallel_group()}
eval_set_args = copy.copy(data_set_args)
eval_set_args['split'] = [1.]
# if optional eval args were set then replace their
# equivalent values in the arg dict
if eval_seq_length:
eval_set_args['seq_length'] = eval_seq_length
if args.eval_max_preds_per_seq:
eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq
if args.eval_text_key is not None:
eval_set_args['text_key'] = args.eval_text_key
# make datasets splits and tokenizer
train = None
valid = None
test = None
if args.train_data is not None:
train, tokenizer = data_utils.make_dataset(**data_set_args)
if data_utils.should_split(split):
train, valid, test = train
eval_set_args['tokenizer'] = tokenizer
# make training and val dataset if necessary
if valid is None and args.valid_data is not None:
eval_set_args['path'] = args.valid_data
valid, tokenizer = data_utils.make_dataset(**eval_set_args)
eval_set_args['tokenizer'] = tokenizer
if test is None and args.test_data is not None:
eval_set_args['path'] = args.test_data
test, tokenizer = data_utils.make_dataset(**eval_set_args)
# wrap datasets with data loader
if train is not None and args.batch_size > 0:
train = make_data_loader(train, batch_size, args)
args.do_train = True
else:
args.do_train = False
eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
if valid is not None:
valid = make_data_loader(valid, eval_batch_size, args)
args.do_valid = True
else:
args.do_valid = False
if test is not None:
test = make_data_loader(test, eval_batch_size, args)
args.do_test = True
else:
args.do_test = False
return (train, valid, test), tokenizer
def get_split(args):
"""
Get dataset splits from comma separated string list
"""
splits = []
if args.split.find(',') != -1:
splits = [float(s) for s in args.split.split(',')]
elif args.split.find('/') != -1:
splits = [float(s) for s in args.split.split('/')]
else:
splits = [float(args.split)]
split_total = sum(splits)
if split_total < 1.:
splits.append(1 - split_total)
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
if args.valid_data is not None:
splits[1] = 0.
if args.test_data is not None:
splits[2] = 0.
final_sum = sum(splits)
return [s / final_sum for s in splits]
def configure_data():
"""add cmdline flags for configuring datasets"""
# These are options that are used by data_utils, but are either
# deprecated or not meant to be exposed to the command line user.
# These options are intneded to be set in code by specific scripts.
defaults = {
'world_size': 1,
'rank': -1,
'persist_state': 0,
'lazy': False,
'transpose': False,
'data_set_type': 'supervised',
'seq_length': 256,
'eval_seq_length': 256,
'samples_per_shard': 100
}
return DataConfig(defaults=defaults)
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""several datasets with preset arguments"""
from .datasets import json_dataset, csv_dataset
import os
class wikipedia(json_dataset):
"""
dataset for wikipedia with arguments configured for convenience
command line usage: `--train-data wikipedia`
"""
PATH = 'data/wikipedia/wikidump_lines.json'
assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py"
def __init__(self, **kwargs):
assert os.path.exists(wikipedia.PATH), \
wikipedia.assert_str
if not kwargs:
kwargs = {}
kwargs['text_key'] = 'text'
kwargs['loose_json'] = True
super(wikipedia, self).__init__(wikipedia.PATH, **kwargs)
class webtext(json_dataset):
"""
dataset for webtext with arguments configured for convenience
command line usage: `--train-data webtext`
"""
PATH = 'data/webtext/data.json'
assert_str = "make sure to set PATH for webtext data_utils/corpora.py"
def __init__(self, **kwargs):
assert os.path.exists(webtext.PATH), \
webtext.assert_str
if not kwargs:
kwargs = {}
kwargs['text_key'] = 'text'
kwargs['loose_json'] = True
super(webtext, self).__init__(webtext.PATH, **kwargs)
NAMED_CORPORA = {
'wikipedia': wikipedia,
'webtext': webtext,
}
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataset objects for jsons, csvs, and BERT datasets"""
import os
import time
from operator import itemgetter
from bisect import bisect_right
import json
import csv
import math
import random
from itertools import accumulate
from torch.utils import data
import pandas as pd
import numpy as np
import nltk
from nltk import tokenize
from .lazy_loader import lazy_array_loader, exists_lazy, make_lazy
from .tokenization import Tokenization
class ConcatDataset(data.Dataset):
"""
Dataset to concatenate multiple datasets.
Purpose: useful to assemble different existing datasets, possibly
large-scale datasets as the concatenation operation is done in an
on-the-fly manner.
Arguments:
datasets (sequence): List of datasets to be concatenated.
"""
@staticmethod
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
def __init__(self, datasets, **kwargs):
super(ConcatDataset, self).__init__()
assert len(datasets) > 0, 'datasets should not be an empty iterable'
self.datasets = list(datasets)
self.is_lazy = sum([isinstance(ds, lazy_array_loader)
for ds in self.datasets]) == len(self.datasets)
self.cumulative_sizes = self.cumsum(self.datasets)
self._X = None
self._Y = None
self._lens = None
def SetTokenizer(self, tokenizer):
for ds in self.datasets:
ds.SetTokenizer(tokenizer)
def GetTokenizer(self):
return self.datasets[0].GetTokenizer()
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
dataset_idx = bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx]
@property
def lens(self):
if self._lens is None:
self._lens = []
if self.is_lazy:
for data in self.datasets:
self._lens.extend(data.lens)
else:
for data in self.datasets:
self._lens.extend([len(d['text']) if isinstance(
d, dict) else len(d) for d in data])
return self._lens
@property
def X(self):
if self._X is None:
self._X = []
for data in self.datasets:
self._X.extend(data.X)
return self._X
@property
def Y(self):
if self._Y is None:
self._Y = []
for data in self.datasets:
self._Y.extend(list(data.Y))
self._Y = np.array(self._Y)
return self._Y
@property
def cummulative_sizes(self):
warnings.warn("cummulative_sizes attribute is renamed to "
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return self.cumulative_sizes
class SplitDataset(data.Dataset):
"""
Dataset wrapper to access a subset of another dataset.
Purpose: useful to index into existing datasets, possibly
large-scale datasets as the subindexing operation is done in an
on-the-fly manner.
Arguments:
ds (Dataset or array-like): List of datasets to be subindexed
split_inds (1D array-like): List of indices part of subset
"""
def __init__(self, ds, split_inds, **kwargs):
self.split_inds = list(split_inds)
self.wrapped_data = ds
self.is_lazy = isinstance(ds, lazy_array_loader) or (hasattr(ds, 'is_lazy') and ds.is_lazy)
if self.is_lazy:
self.lens = itemgetter(*self.split_inds)(list(self.wrapped_data.lens))
self._X = None
self._Y = None
def __len__(self):
return len(self.split_inds)
def __getitem__(self, index):
return self.wrapped_data[self.split_inds[index]]
def SetTokenizer(self, tokenizer):
self.wrapped_data.SetTokenizer(tokenizer)
def GetTokenizer(self):
return self.wrapped_data.GetTokenizer()
@property
def X(self):
if self._X is None:
self._X = itemgetter(*self.split_inds)(self.wrapped_data.X)
return self._X
@property
def Y(self):
if self._Y is None:
self._Y = np.array(itemgetter(*self.split_inds)(self.wrapped_data.Y))
return self._Y
def __iter__(self):
for idx in self.split_inds:
yield self.wrapped_data[idx]
def split_ds(ds, split=[.8, .2, .0], shuffle=True):
"""
Split a dataset into subsets given proportions of how
much to allocate per split. If a split is 0% returns None for that split.
Purpose: Useful for creating train/val/test splits
Arguments:
ds (Dataset or array-like): Data to be split.
split (1D array-like): proportions to split `ds`. `sum(splits) != 0`
shuffle (boolean): Randomly split dataset. Default: True
"""
split_sum = sum(split)
if split_sum == 0:
raise Exception('Split cannot sum to 0.')
split = np.array(split)
split /= split_sum
ds_len = len(ds)
inds = np.arange(ds_len)
if shuffle:
np.random.shuffle(inds)
start_idx = 0
residual_idx = 0
rtn_ds = [None] * len(split)
for i, f in enumerate(split):
if f != 0:
proportion = ds_len * split[i]
residual_idx += proportion % 1
split_ = int(int(proportion) + residual_idx)
split_inds = inds[start_idx:start_idx + max(split_, 1)]
rtn_ds[i] = SplitDataset(ds, split_inds)
start_idx += split_
residual_idx %= 1
return rtn_ds
class csv_dataset(data.Dataset):
"""
Class for loading datasets from csv files.
Purpose: Useful for loading data for unsupervised modeling or transfer tasks
Arguments:
path (str): Path to csv file with dataset.
tokenizer (data_utils.Tokenizer): Tokenizer to use when processing text. Default: None
preprocess_fn (callable): Callable that process a string into desired format.
delim (str): delimiter for csv. Default: ','
binarize_sent (bool): binarize label values to 0 or 1 if they\'re on a different scale. Default: False
drop_unlabeled (bool): drop rows with unlabelled values. Always fills remaining empty
columns with -1 (regardless if rows are dropped based on value) Default: False
text_key (str): key to get text from csv. Default: 'sentence'
label_key (str): key to get label from json dictionary. Default: 'label'
Attributes:
X (list): all strings from the csv file
Y (np.ndarray): labels to train with
"""
def __init__(self, path, tokenizer=None, preprocess_fn=None, delim=',',
binarize_sent=False, drop_unlabeled=False, text_key='sentence', label_key='label',
**kwargs):
self.is_lazy = False
self.preprocess_fn = preprocess_fn
self.SetTokenizer(tokenizer)
self.path = path
self.delim = delim
self.text_key = text_key
self.label_key = label_key
self.drop_unlabeled = drop_unlabeled
if '.tsv' in self.path:
self.delim = '\t'
self.X = []
self.Y = []
try:
cols = [text_key]
if isinstance(label_key, list):
cols += label_key
else:
cols += [label_key]
data = pd.read_csv(self.path, sep=self.delim, usecols=cols, encoding='latin-1')
except BaseException:
data = pd.read_csv(self.path, sep=self.delim, usecols=[text_key], encoding='latin-1')
data = data.dropna(axis=0)
self.X = data[text_key].values.tolist()
try:
self.Y = data[label_key].values
except Exception as e:
self.Y = np.ones(len(self.X)) * -1
if binarize_sent:
self.Y = binarize_labels(self.Y, hard=binarize_sent)
def SetTokenizer(self, tokenizer):
if tokenizer is None:
self.using_tokenizer = False
if not hasattr(self, '_tokenizer'):
self._tokenizer = tokenizer
else:
self.using_tokenizer = True
self._tokenizer = tokenizer
def GetTokenizer(self):
return self._tokenizer
@property
def tokenizer(self):
if self.using_tokenizer:
return self._tokenizer
return None
def __len__(self):
return len(self.X)
def __getitem__(self, index):
"""process+tokenize string and return string,label,and stringlen"""
x = self.X[index]
if self.tokenizer is not None:
x = self.tokenizer.EncodeAsIds(x, self.preprocess_fn)
elif self.preprocess_fn is not None:
x = self.preprocess_fn(x)
y = self.Y[index]
if isinstance(y, str):
if self.tokenizer is not None:
y = self.tokenizer.EncodeAsIds(y, self.preprocess_fn)
elif self.preprocess_fn is not None:
y = self.preprocess_fn(y)
return {'text': x, 'length': len(x), 'label': y}
def write(self, writer_gen=None, path=None, skip_header=False):
"""
given a generator of metrics for each of the data points X_i,
write the metrics, text, and labels to a csv file
"""
if path is None:
path = self.path + '.results'
print('generating csv at ' + path)
with open(path, 'w') as csvfile:
c = csv.writer(csvfile, delimiter=self.delim)
if writer_gen is not None:
# if first item of generator is a header of what the metrics mean then
# write header to csv file
if not skip_header:
header = (self.label_key,) + tuple(next(writer_gen)) + (self.text_key,)
c.writerow(header)
for i, row in enumerate(writer_gen):
row = (self.Y[i],) + tuple(row) + (self.X[i],)
c.writerow(row)
else:
c.writerow([self.label_key, self.text_key])
for row in zip(self.Y, self.X):
c.writerow(row)
class json_dataset(data.Dataset):
"""
Class for loading datasets from a json dump.
Purpose: Useful for loading data for unsupervised modeling or transfer tasks
Arguments:
path (str): path to json file with dataset.
tokenizer (data_utils.Tokenizer): Tokenizer to use when processing text. Default: None
preprocess_fn (callable): callable function that process a string into desired format.
Takes string, maxlen=None, encode=None as arguments. Default: process_str
text_key (str): key to get text from json dictionary. Default: 'sentence'
label_key (str): key to get label from json dictionary. Default: 'label'
Attributes:
all_strs (list): list of all strings from the dataset
all_labels (list): list of all labels from the dataset (if they have it)
"""
def __init__(self, path, tokenizer=None, preprocess_fn=None, binarize_sent=False,
text_key='sentence', label_key='label', loose_json=False, **kwargs):
self.is_lazy = False
self.preprocess_fn = preprocess_fn
self.path = path
self.SetTokenizer(tokenizer)
self.X = []
self.Y = []
self.text_key = text_key
self.label_key = label_key
self.loose_json = loose_json
for j in self.load_json_stream(self.path):
s = j[text_key]
self.X.append(s)
self.Y.append(j[label_key])
if binarize_sent:
self.Y = binarize_labels(self.Y, hard=binarize_sent)
def SetTokenizer(self, tokenizer):
if tokenizer is None:
self.using_tokenizer = False
if not hasattr(self, '_tokenizer'):
self._tokenizer = tokenizer
else:
self.using_tokenizer = True
self._tokenizer = tokenizer
def GetTokenizer(self):
return self._tokenizer
@property
def tokenizer(self):
if self.using_tokenizer:
return self._tokenizer
return None
def __getitem__(self, index):
"""gets the index'th string from the dataset"""
x = self.X[index]
if self.tokenizer is not None:
x = self.tokenizer.EncodeAsIds(x, self.preprocess_fn)
elif self.preprocess_fn is not None:
x = self.preprocess_fn(x)
y = self.Y[index]
if isinstance(y, str):
if self.tokenizer is not None:
y = self.tokenizer.EncodeAsIds(y, self.preprocess_fn)
elif self.preprocess_fn is not None:
y = self.preprocess_fn(y)
return {'text': x, 'length': len(x), 'label': y}
def __len__(self):
return len(self.X)
def write(self, writer_gen=None, path=None, skip_header=False):
"""
given a generator of metrics for each of the data points X_i,
write the metrics, text, and labels to a json file
"""
if path is None:
path = self.path + '.results'
jsons = []
if writer_gen is not None:
# if first item of generator is a header of what the metrics mean then
# write header to csv file
def gen_helper():
keys = {}
keys[0] = self.label_key
if not skip_header:
for idx, k in enumerate(tuple(next(writer_gen))):
keys[idx + 1] = k
for i, row in enumerate(writer_gen):
if i == 0 and skip_header:
for idx, _ in enumerate(row):
keys[idx + 1] = 'metric_%d' % (idx,)
j = {}
for idx, v in enumerate((self.Y[i],) + tuple(row)):
k = keys[idx]
j[k] = v
yield j
else:
def gen_helper():
for y in self.Y:
j = {}
j[self.label_key] = y
yield j
def out_stream():
for i, j in enumerate(gen_helper()):
j[self.text_key] = self.X[i]
yield j
self.save_json_stream(path, out_stream())
def save_json_stream(self, save_path, json_stream):
if self.loose_json:
with open(save_path, 'w') as f:
for i, j in enumerate(json_stream):
write_string = ''
if i != 0:
write_string = '\n'
write_string += json.dumps(j)
f.write(write_string)
else:
jsons = [j for j in json_stream]
json.dump(jsons, open(save_path, 'w'), separators=(',', ':'))
def load_json_stream(self, load_path):
if not self.loose_json:
jsons = json.load(open(load_path, 'r'))
generator = iter(jsons)
else:
def gen_helper():
with open(load_path, 'r') as f:
for row in f:
yield json.loads(row)
generator = gen_helper()
for j in generator:
if self.label_key not in j:
j[self.label_key] = -1
yield j
class GPT2Dataset(data.Dataset):
def __init__(self, ds,
max_seq_len=1024,
num_samples=None,
weighted=True,
sample_across_doc=True,
random_across_doc_sampling=True,
bias_for_single_doc=False,
sentence_start=False, **kwargs):
self.ds = ds
self.ds_len = len(self.ds)
self.num_samples = num_samples
if num_samples is None:
self.num_samples = 1000 * self.ds_len
self.max_seq_len = max_seq_len
self.tokenizer = self.ds.GetTokenizer()
self.ds.SetTokenizer(None)
self.weighted = weighted
self.sample_across_doc = sample_across_doc
self.random_across_doc_sampling = random_across_doc_sampling
self.bias_for_single_doc = bias_for_single_doc
self.sentence_start = sentence_start
self.init_weighting()
def init_weighting(self):
if self.weighted:
if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
lens = np.array(self.ds.lens)
else:
lens = np.array([len(d['text']) if isinstance(d, dict)
else len(d) for d in self.ds])
self.total_len = np.sum(lens)
self.weighting = list(accumulate(lens))
else:
self.weighting = None
def get_weighted_samples(self, np_rng):
if self.weighting is not None:
idx = np_rng.randint(self.total_len)
return bisect_right(self.weighting, idx)
else:
return np_rng.randint(self.ds_len)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# init rng
rng = random.Random(idx)
rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
# get possibly weighted random index from dataset
data_idx = self.get_weighted_samples(rng)
# data_idx = rng.choice(self.ds_len, p=self.weighting)
tokens = self.getidx(data_idx)
# truncate or pad tokens
num_tokens = len(tokens)
if self.bias_for_single_doc:
tokens_to_strip = num_tokens - self.max_seq_len - 1
else:
tokens_to_strip = num_tokens - 1
if tokens_to_strip > 0:
strip_left_tokens = rng.randint(tokens_to_strip + 1)
tokens = tokens[strip_left_tokens:]
if self.sentence_start:
token_copy = list(tokens)
not_done = True
while (len(token_copy) > 0) and not_done:
tok = token_copy.pop(0)
if self.contains_sentence_end(tok):
tokens = token_copy
not_done = False
strip_right_rokens = len(tokens) - self.max_seq_len - 1
if strip_right_rokens > 0:
tokens = tokens[:-strip_right_rokens]
if self.sample_across_doc:
while (len(tokens) < (self.max_seq_len + 1)):
if self.random_across_doc_sampling:
data_idx = self.get_weighted_samples(rng)
else:
data_idx = (data_idx + 1) % self.ds_len
tokens += self.getidx(data_idx)
tokens = tokens[:(self.max_seq_len + 1)]
tokens = self.pad_seq(tokens)
return {'text': np.array(tokens), }
def getidx(self, data_idx):
data = self.ds[data_idx]
if isinstance(data, dict):
data = data['text']
# tokenize
tokenization = self.tokenizer.EncodeAsIds(data)
tokenization.append(self.tokenizer.get_command('eos'))
tokens = tokenization.tokenization
return tokens
def pad_seq(self, seq):
total_tokens = self.max_seq_len + 1
num_pad_tokens = max(0, total_tokens - len(seq))
seq += [self.tokenizer.get_command('pad').Id] * (num_pad_tokens)
return seq
def contains_sentence_end(self, tok):
tok = self.tokenizer.IdToToken(tok)
if '.' in tok:
return True
if '?' in tok:
return True
if '!' in tok:
return True
return False
class bert_sentencepair_dataset(data.Dataset):
"""
Dataset containing sentencepairs for BERT training. Each index corresponds to a randomly generated sentence pair.
Arguments:
ds (Dataset or array-like): data corpus to use for training
max_seq_len (int): maximum sequence length to use for a sentence pair
mask_lm_prob (float): proportion of tokens to mask for masked LM
max_preds_per_seq (int): Maximum number of masked tokens per sentence pair. Default: math.ceil(max_seq_len*mask_lm_prob/10)*10
short_seq_prob (float): Proportion of sentence pairs purposefully shorter than max_seq_len
dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
"""
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None,
short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True, **kwargs):
self.ds = ds
self.ds_len = len(self.ds)
self.tokenizer = self.ds.GetTokenizer()
self.vocab_words = list(self.tokenizer.text_token_vocab.values())
self.ds.SetTokenizer(None)
self.max_seq_len = max_seq_len
self.mask_lm_prob = mask_lm_prob
if max_preds_per_seq is None:
max_preds_per_seq = math.ceil(max_seq_len * mask_lm_prob / 10) * 10
self.max_preds_per_seq = max_preds_per_seq
self.short_seq_prob = short_seq_prob
self.dataset_size = dataset_size
if self.dataset_size is None:
self.dataset_size = self.ds_len * (self.ds_len - 1)
self.presplit_sentences = presplit_sentences
if not self.presplit_sentences:
nltk.download('punkt', download_dir="./nltk")
self.weighted = weighted
self.get_weighting()
def get_weighting(self):
if self.weighted:
if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
lens = np.array(self.ds.lens)
else:
lens = np.array([len(d['text']) if isinstance(d, dict) else len(d)
for d in self.ds])
self.total_len = np.sum(lens)
self.weighting = list(accumulate(lens))
else:
self.weighting = None
def get_weighted_samples(self, np_rng):
if self.weighting is not None:
idx = np_rng.randint(self.total_len)
return bisect_right(self.weighting, idx)
else:
return np_rng.randint(self.ds_len)
def __len__(self):
return self.dataset_size
def __getitem__(self, idx):
# get rng state corresponding to index (allows deterministic random pair)
rng = random.Random(idx)
np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
# get seq length
target_seq_length = self.max_seq_len
short_seq = False
if rng.random() < self.short_seq_prob:
target_seq_length = rng.randint(2, target_seq_length)
short_seq = True
# get sentence pair and label
is_random_next = None
lena = 0
lenb = 0
while (is_random_next is None) or (lena < 1) or (lenb < 1):
tokensa, tokensb, is_random_next = self.create_random_sentencepair(
target_seq_length, rng, np_rng)
lena = len(tokensa[0])
lenb = len(tokensb[0])
# truncate sentence pair to max_seq_len
tokensa, tokensb = self.truncate_seq_pair(tokensa, tokensb, self.max_seq_len, rng)
# join sentence pair, mask, and pad
tokens, mask, mask_labels, pad_mask = self.create_masked_lm_predictions(
tokensa, tokensb, self.mask_lm_prob, self.max_preds_per_seq, self.vocab_words, rng)
sample = {
'text': np.array(
tokens[0]),
'types': np.array(
tokens[1]),
'is_random': int(is_random_next),
'mask': np.array(mask),
'mask_labels': np.array(mask_labels),
'pad_mask': np.array(pad_mask)}
return sample
def sentence_split(self, document):
"""split document into sentences"""
lines = document.split('\n')
if self.presplit_sentences:
return [line for line in lines if line]
rtn = []
for line in lines:
if line != '':
rtn.extend(tokenize.sent_tokenize(line))
return rtn
def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False):
"""tokenize sentence and get token types"""
tokens = self.tokenizer.EncodeAsIds(sent).tokenization
str_type = 'str' + str(sentence_num)
token_types = [self.tokenizer.get_type(str_type).Id] * len(tokens)
return tokens, token_types
def get_doc(self, idx):
"""gets text of document corresponding to idx"""
rtn = self.ds[idx]
if isinstance(rtn, dict):
rtn = rtn['text']
return rtn
def create_random_sentencepair(self, target_seq_length, rng, np_rng):
"""
fetches a random sentencepair corresponding to rng state similar to
https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L248-L294
"""
is_random_next = None
curr_strs = []
curr_str_types = []
curr_len = 0
while curr_len < 1:
curr_len = 0
doc_a = None
while doc_a is None:
if self.weighted:
# doc_a_idx = np_rng.choice(self.ds_len, p=self.weighting)
doc_a_idx = self.get_weighted_samples(np_rng)
else:
doc_a_idx = rng.randint(0, self.ds_len - 1)
doc_a = self.sentence_split(self.get_doc(doc_a_idx))
if not doc_a:
doc_a = None
random_start_a = rng.randint(0, len(doc_a) - 1)
while random_start_a < len(doc_a):
sentence = doc_a[random_start_a]
sentence, sentence_types = self.sentence_tokenize(
sentence, 0, random_start_a == 0, random_start_a == len(doc_a))
curr_strs.append(sentence)
curr_str_types.append(sentence_types)
curr_len += len(sentence)
if random_start_a == len(doc_a) - 1 or curr_len >= target_seq_length:
break
random_start_a = (random_start_a + 1)
if curr_strs:
num_a = 1
if len(curr_strs) >= 2:
num_a = rng.randint(0, len(curr_strs))
tokens_a = []
token_types_a = []
for j in range(num_a):
tokens_a.extend(curr_strs[j])
token_types_a.extend(curr_str_types[j])
tokens_b = []
token_types_b = []
is_random_next = False
if len(curr_strs) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
b_len = 0
while b_len < 1:
doc_b = None
while doc_b is None:
doc_b_idx = rng.randint(0, self.ds_len - 2)
doc_b_idx += int(doc_b_idx >= doc_a_idx)
doc_b = self.sentence_split(self.get_doc(doc_b_idx))
if not doc_b:
doc_b = None
random_start_b = rng.randint(0, len(doc_b) - 1)
while random_start_b < len(doc_b):
sentence_b = doc_b[random_start_b]
new_b_tokens, new_b_types = self.sentence_tokenize(
sentence_b, 1, random_start_b == 0, random_start_b == len(doc_b))
b_len += len(new_b_tokens)
tokens_b.extend(new_b_tokens)
token_types_b.extend(new_b_types)
if len(tokens_b) >= target_b_length:
break
random_start_b = (random_start_b + 1)
else:
is_random_next = False
for j in range(num_a, len(curr_strs)):
tokens_b.extend(curr_strs[j])
token_types_b.extend(curr_str_types[j])
return (tokens_a, token_types_a), (tokens_b, token_types_b), is_random_next
def truncate_seq_pair(self, a, b, max_seq_len, rng):
"""
Truncate sequence pair according to original BERT implementation:
https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L391
"""
tokens_a, token_types_a = a
tokens_b, token_types_b = b
max_num_tokens = self.calc_seq_len(max_seq_len)
# max_num_tokens = max_seq_len - 3
while True:
len_a = len(tokens_a)
len_b = len(tokens_b)
total_length = len_a + len_b
if total_length <= max_num_tokens:
break
if len(tokens_a) > len(tokens_b):
trunc_tokens = tokens_a
trunc_types = token_types_a
else:
trunc_tokens = tokens_b
trunc_types = token_types_b
assert len(trunc_tokens) >= 1
if rng.random() < 0.5:
trunc_tokens.pop(0)
trunc_types.pop(0)
else:
trunc_tokens.pop()
trunc_types.pop()
return (tokens_a, token_types_a), (tokens_b, token_types_b)
def calc_seq_len(self, max_seq_len):
return max_seq_len - 3
def mask_token(self, idx, tokens, types, vocab_words, rng):
"""
helper function to mask `idx` token from `tokens` according to
section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf
"""
label = tokens[idx]
if rng.random() < 0.8:
new_label = self.tokenizer.get_command('MASK').Id
else:
if rng.random() < 0.5:
new_label = label
else:
new_label = rng.choice(vocab_words)
tokens[idx] = new_label
return label
def pad_seq(self, seq):
"""helper function to pad sequence pair"""
num_pad = max(0, self.max_seq_len - len(seq))
pad_mask = [0] * len(seq) + [1] * num_pad
seq += [self.tokenizer.get_command('pad').Id] * num_pad
return seq, pad_mask
def concat_tokens(self, tokens_a, token_types_a, tokens_b, token_types_b):
tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command(
'sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
token_types = [token_types_a[0]] + token_types_a + \
[token_types_a[0]] + token_types_b + [token_types_b[0]]
return tokens, token_types
def create_masked_lm_predictions(self, a, b, mask_lm_prob, max_preds_per_seq, vocab_words, rng):
"""
Mask sequence pair for BERT training according to:
https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L338
"""
tokens_a, token_types_a = a
tokens_b, token_types_b = b
tokens, token_types = self.concat_tokens(tokens_a, token_types_a, tokens_b, token_types_b)
len_a = len(tokens_a)
len_b = len(tokens_b)
cand_indices = [idx + 1 for idx in range(len_a)] + [idx + 2 + len_a for idx in range(len_b)]
rng.shuffle(cand_indices)
output_tokens, pad_mask = self.pad_seq(list(tokens))
output_types, _ = self.pad_seq(list(token_types))
num_to_predict = min(max_preds_per_seq, max(1, int(round(len(tokens) * mask_lm_prob))))
mask = [0] * len(output_tokens)
mask_labels = [-1] * len(output_tokens)
for idx in sorted(cand_indices[:num_to_predict]):
mask[idx] = 1
label = self.mask_token(idx, output_tokens, output_types, vocab_words, rng)
mask_labels[idx] = label
return (output_tokens, output_types), mask, mask_labels, pad_mask
# This file is provided as is from:
# https://github.com/huggingface/pytorch-pretrained-BERT
# Please refer to their repository for copyright.
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)
import json
import logging
import os
import shutil
import tempfile
from functools import wraps
from hashlib import sha256
import sys
from io import open
import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert'))
except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
"""
url_bytes = url.encode('utf-8')
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode('utf-8')
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()
return filename
def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + '.json'
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata['url']
etag = metadata['etag']
return url, etag
def cached_path(url_or_filename, cache_dir=None):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename)
if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif parsed.scheme == '':
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url):
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path
def s3_request(func):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise EnvironmentError("file {} not found".format(url))
else:
raise
return wrapper
@s3_request
def s3_etag(url):
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag
@s3_request
def s3_get(url, temp_file):
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file):
req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url, cache_dir=None):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url)
else:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError("HEAD request failed for url {} with status code {}"
.format(url, response.status_code))
etag = response.headers.get("ETag")
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file)
else:
http_get(url, temp_file)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, 'wb') as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w', encoding="utf-8") as meta_file:
json.dump(meta, meta_file)
logger.info("removing temp file %s", temp_file.name)
return cache_path
def read_set_from_file(filename):
'''
Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line.
'''
collection = set()
with open(filename, 'r', encoding='utf-8') as file_:
for line in file_:
collection.add(line.rstrip())
return collection
def get_file_extension(path, dot=True, lower=True):
ext = os.path.splitext(path)[1]
ext = ext if dot else ext[1:]
return ext.lower() if lower else ext
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""utils for loading text from disk"""
import os
import mmap
import pickle as pkl
import time
from itertools import accumulate
import torch
from torch.multiprocessing import Lock
def get_lazy_path(path):
"""
Gets directory path where lazy files are stored.
"""
return os.path.splitext(path)[0] + '.lazy'
def exists_lazy(path, data_type='data'):
"""
Check if we've already made a lazy version of this file for the `data_type` field.
"""
if not os.path.exists(get_lazy_path(path)):
return False
contents = os.listdir(get_lazy_path(path))
if data_type not in contents:
return False
if data_type + '.len.pkl' not in contents:
return False
return True
def make_lazy(path, strs, data_type='data'):
"""
Make lazy version of `data_type` field of the file. Byte offsets
corresponding to data indices are stored in a `.len.pkl` data file.
"""
lazypath = get_lazy_path(path)
if not os.path.exists(lazypath):
os.makedirs(lazypath)
datapath = os.path.join(lazypath, data_type)
lenpath = os.path.join(lazypath, data_type + '.len.pkl')
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
with open(datapath, 'wb') as f:
str_lens = []
str_cnt = 0
for s in strs:
if isinstance(s, dict):
s = s['text']
encoded = s.encode('utf-8')
f.write(encoded)
str_cnt = len(encoded)
str_lens.append(str_cnt)
pkl.dump(str_lens, open(lenpath, 'wb'))
else:
while not os.path.exists(lenpath):
time.sleep(1)
def split_strings(strings, start, chr_lens):
"""
Split strings based on string lengths and given start.
"""
return [strings[i - start:j - start] for i, j in zip([start] + chr_lens[:-1], chr_lens)]
class ProcessorTokenizer:
"""
callable class that runs a preprocessing, as well as tokenization step,
on input text.
"""
def __init__(self, tokenizer, process_fn=None):
self.tokenizer = tokenizer
self.process_fn = process_fn
def __call__(self, string):
if self.tokenizer is not None:
string = self.tokenizer(string, process_fn=self.process_fn)
elif self.process_fn is not None:
string = self.process_fn(string)
return string
class lazy_array_loader(object):
"""
Arguments:
path: path to directory where array entries are concatenated into one big string file
and the .len file are located
data_type (str): Some datsets have multiple fields that are stored in different paths.
`data_type` specifies which of these fields to load in this class
mem_map (boolean): Specifies whether to memory map file `path`
map_fn (callable): Fetched strings are passed through map_fn before being returned.
Example of lazy loader directory structure:
file.json
file.lazy/
data_type1
data_type1.len.pkl
data_type2
data_type2.len.pkl
"""
def __init__(self, path, data_type='data', mem_map=False, map_fn=None):
lazypath = get_lazy_path(path)
datapath = os.path.join(lazypath, data_type)
# get file where array entries are concatenated into one big string
self._file = open(datapath, 'rb', buffering=0)
self.file = self._file
# memory map file if necessary
self.mem_map = mem_map
if self.mem_map:
self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ)
lenpath = os.path.join(lazypath, data_type + '.len.pkl')
self.lens = pkl.load(open(lenpath, 'rb'))
self.ends = list(accumulate(self.lens))
self.dumb_ends = list(self.ends)
self.read_lock = Lock()
self.process_fn = map_fn
self.map_fn = map_fn
self._tokenizer = None
def SetTokenizer(self, tokenizer):
"""
logic to set and remove (set to None) tokenizer.
combines preprocessing/tokenization into one callable.
"""
if tokenizer is None:
if not hasattr(self, '_tokenizer'):
self._tokenizer = tokenizer
else:
self._tokenizer = tokenizer
self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn)
def GetTokenizer(self):
return self._tokenizer
def __getitem__(self, index):
"""
read file and splice strings based on string ending array `self.ends`
"""
if not isinstance(index, slice):
if index == 0:
start = 0
else:
start = self.ends[index - 1]
end = self.ends[index]
rtn = self.file_read(start, end)
if self.map_fn is not None:
return self.map_fn(rtn)
else:
# if slice, fetch strings with 1 diskread and then splice in memory
chr_lens = self.ends[index]
if index.start == 0 or index.start is None:
start = 0
else:
start = self.ends[index.start - 1]
stop = chr_lens[-1]
strings = self.file_read(start, stop)
rtn = split_strings(strings, start, chr_lens)
if self.map_fn is not None:
return self.map_fn([s for s in rtn])
return rtn
def __len__(self):
return len(self.ends)
def file_read(self, start=0, end=None):
"""read specified portion of file"""
# atomic reads to avoid race conditions with multiprocess dataloader
self.read_lock.acquire()
# seek to start of file read
self.file.seek(start)
# read to end of file if no end point provided
if end is None:
rtn = self.file.read()
# else read amount needed to reach end point
else:
rtn = self.file.read(end - start)
self.read_lock.release()
# TODO: @raulp figure out mem map byte string bug
# if mem map'd need to decode byte string to string
rtn = rtn.decode('utf-8', 'ignore')
# rtn = str(rtn)
if self.mem_map:
rtn = rtn.decode('unicode_escape')
return rtn
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""batch samplers that work with either random or sequential data samplers"""
import math
import os
import sys
import torch
from torch.utils import data
import numpy as np
class RandomSampler(data.sampler.Sampler):
r"""
Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
but this class lets the user set an epoch like DistributedSampler
Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify ``num_samples`` to draw.
Arguments:
data_source (Dataset): dataset to sample from
num_samples (int): number of samples to draw, default=len(dataset)
replacement (bool): samples are drawn with replacement if ``True``, default=False
"""
def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.epoch = -1
if self._num_samples is not None and replacement is False:
raise ValueError("With replacement=False, num_samples should not be specified, "
"since a random permute will be performed.")
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(self.num_samples))
if not isinstance(self.replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement))
@property
def num_samples(self):
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self):
n = len(self.data_source)
g = torch.Generator()
if self.epoch >= 0:
g.manual_seed(self.epoch)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,),
dtype=torch.int64, generator=g).tolist())
return iter(torch.randperm(n, generator=g).tolist())
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
class DistributedBatchSampler(data.sampler.BatchSampler):
"""
similar to normal implementation of distributed sampler, except implementation is at the
batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary
data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
"""
def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False):
super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
if rank == -1:
assert False, 'should not be here'
rank = torch.distributed.get_rank()
self.rank = rank
self.world_size = world_size
self.sampler.wrap_around = 0
self.wrap_around = 0
self.wrap_last = wrap_last
self.start_iter = 0
def __iter__(self):
batch = []
last_batch = None
i = 0
for idx in self.data_iterator(self.sampler, wrap_around=False):
batch.append(idx)
if len(batch) == self.batch_size:
tbatch = self._batch(batch)
if i >= self.start_iter:
yield tbatch
self.start_iter = 0
i += 1
last_batch = np.array(list(tbatch))
batch = []
batch_len = len(batch)
if batch_len > 0 and not self.drop_last:
if self.wrap_last:
self.sampler.wrap_around -= (self.batch_size)
self.wrap_around += (len(batch))
self.wrap_around %= self.batch_size
if isinstance(self.sampler, TransposedSampler):
for i, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)):
if i == 0:
continue
batch.append(idx)
new_batch_len = len(batch)
if len(batch) == self.batch_size:
break
yield self._batch(batch)
if self.wrap_last:
self.sampler.wrap_around += self.batch_size
def data_iterator(self, _iter, wrap_around=False):
"""iterates through data and handles wrap around"""
for i, idx in enumerate(_iter):
if i < self.wrap_around % self.batch_size:
continue
if wrap_around:
self.wrap_around += 1
self.wrap_around %= self.batch_size
yield idx
def _batch(self, batch):
"""extracts samples only pertaining to this worker's batch"""
start = self.rank * self.batch_size // self.world_size
end = (self.rank + 1) * self.batch_size // self.world_size
return batch[start:end]
"""
Usage:
python scripts/presplit_sentences_json.py <original loose json file> <output loose json file>
"""
import sys
import json
import nltk
nltk.download('punkt')
input_file = sys.argv[1]
output_file = sys.argv[2]
line_seperator = "\n"
with open(input_file, 'r') as ifile:
with open(output_file, "w") as ofile:
for doc in ifile.readlines():
parsed = json.loads(doc)
sent_list = []
for line in parsed['text'].split('\n'):
if line != '\n':
sent_list.extend(nltk.tokenize.sent_tokenize(line))
parsed['text'] = line_seperator.join(sent_list)
ofile.write(json.dumps(parsed) + '\n')
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Takes a corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, val.json, test.json files
under `output_dir`.
Note: This code has the potential to override files with the names
train.json, val.json, test.json in `--output_dir`.
"""
import os
import argparse
import math
import random
parser = argparse.ArgumentParser('resplit loose json data into train/val/test')
parser.add_argument('--input_files', nargs='+', required=True,
help='whitespace separated list of input data files')
parser.add_argument('--output_dir', required=True,
help='output directory where to put files')
parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
help='percentage of available data to use for val/test dataset')
args = parser.parse_args()
def get_lines(filepath):
lines = []
with open(filepath, 'r') as f:
for i, l in enumerate(f.readlines()):
l = l.strip()
lines.append(l)
return lines
def get_splits(lines, line_counts):
all_lines = []
line_idx = []
file_mappings = []
for i, l in enumerate(lines):
all_lines.extend(l)
line_idx.extend(list(range(len(l))))
file_mappings.extend([i] * len(l))
indices = list(range(len(all_lines)))
random.shuffle(indices)
all_lines = [all_lines[idx] for idx in indices]
line_idx = [line_idx[idx] for idx in indices]
file_mappings = [file_mappings[idx] for idx in indices]
splits = []
mappings = []
start = 0
for end in line_counts:
end += start
splits.append(all_lines[start:end])
mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end]))
start = end
return splits, mappings
def format_mappings(line_idx, file_mappings):
lines = []
for m, l in zip(file_mappings, line_idx):
lines.append(str(m).strip() + '\t' + str(l).strip())
return lines
def get_filepaths(filepaths, output_dir):
paths = []
train_path = 'train.json'
dev_path = 'dev.json'
test_path = 'test.json'
paths.append(os.path.join(output_dir, train_path))
paths.append(os.path.join(output_dir, dev_path))
paths.append(os.path.join(output_dir, test_path))
return paths
def write_files(lines, mappings, filepaths):
for l, m, path in zip(lines, mappings, filepaths):
write_file(l, path)
write_mapping_file(m, path)
def write_file(lines, path):
print('Writing:', path)
with open(path, 'w') as f:
for l in lines:
f.write(l + '\n')
def write_mapping_file(m, path):
path = path + '.map'
m = [get_mapping_header()] + m
write_file(m, path)
def get_mapping_header():
return 'file\tline #'
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
lines = []
for filepath in args.input_files:
_lines = get_lines(filepath)
lines.append(_lines)
# calculate number of lines to use for each
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent * total_lines)
test_percent = 0
if len(args.test_percent) == 2:
test_percent = args.test_percent[1]
test_lines = math.ceil(test_percent * total_lines)
train_lines = total_lines - (test_lines + dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)
"""
Takes a corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, val.json, test.json files
under `output_dir`.
Note: This code has the potential to override files with the names
train.json, val.json, test.json in `--output_dir`.
"""
import os
import argparse
import math
import random
parser = argparse.ArgumentParser('resplit loose json data into train/val/test')
parser.add_argument('--input_files', nargs='+', required=True,
help='whitespace separated list of input data files')
parser.add_argument('--output_dir', required=True,
help='output directory where to put files')
parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
help='percentage of available data to use for val/test dataset')
args = parser.parse_args()
def get_lines(filepath):
lines = []
with open(filepath, 'r') as f:
for i, l in enumerate(f.readlines()):
l = l.strip()
lines.append(l)
return lines
def get_splits(lines, line_counts):
all_lines = []
line_idx = []
file_mappings = []
for i, l in enumerate(lines):
all_lines.extend(l)
line_idx.extend(list(range(len(l))))
file_mappings.extend([i] * len(l))
indices = list(range(len(all_lines)))
random.shuffle(indices)
all_lines = [all_lines[idx] for idx in indices]
line_idx = [line_idx[idx] for idx in indices]
file_mappings = [file_mappings[idx] for idx in indices]
splits = []
mappings = []
start = 0
for end in line_counts:
end += start
splits.append(all_lines[start:end])
mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end]))
start = end
return splits, mappings
def format_mappings(line_idx, file_mappings):
lines = []
for m, l in zip(file_mappings, line_idx):
lines.append(str(m).strip() + '\t' + str(l).strip())
return lines
def get_filepaths(filepaths, output_dir):
paths = []
train_path = 'train.json'
dev_path = 'dev.json'
test_path = 'test.json'
paths.append(os.path.join(output_dir, train_path))
paths.append(os.path.join(output_dir, dev_path))
paths.append(os.path.join(output_dir, test_path))
return paths
def write_files(lines, mappings, filepaths):
for l, m, path in zip(lines, mappings, filepaths):
write_file(l, path)
write_mapping_file(m, path)
def write_file(lines, path):
print('Writing:', path)
with open(path, 'w') as f:
for l in lines:
f.write(l + '\n')
def write_mapping_file(m, path):
path = path + '.map'
m = [get_mapping_header()] + m
write_file(m, path)
def get_mapping_header():
return 'file\tline #'
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
lines = []
for filepath in args.input_files:
_lines = get_lines(filepath)
lines.append(_lines)
# calculate number of lines to use for each
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent * total_lines)
test_percent = 0
if len(args.test_percent) == 2:
test_percent = args.test_percent[1]
test_lines = math.ceil(test_percent * total_lines)
train_lines = total_lines - (test_lines + dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DataLoader for TFRecords"""
import numpy as np
import torch
import queue
import threading
import tensorflow as tf
tf.enable_eager_execution()
class TFRecordDataLoader(object):
def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq,
train, num_workers=2, seed=1, threaded_dl=False):
assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords"
tf.set_random_seed(seed)
if isinstance(records, str):
records = [records]
self.record_converter = Record2Example({"input_ids": tf.FixedLenFeature([max_seq_len], tf.int64),
"input_mask": tf.FixedLenFeature([max_seq_len], tf.int64),
"segment_ids": tf.FixedLenFeature([max_seq_len], tf.int64),
"masked_lm_positions": tf.FixedLenFeature([max_preds_per_seq], tf.int64),
"masked_lm_ids": tf.FixedLenFeature([max_preds_per_seq], tf.int64),
"masked_lm_weights": tf.FixedLenFeature([max_preds_per_seq], tf.float32),
"next_sentence_labels": tf.FixedLenFeature([1], tf.int64)})
# Instantiate dataset according to original BERT implementation
if train:
self.dataset = tf.data.Dataset.from_tensor_slices(tf.constant(records))
self.dataset = self.dataset.repeat()
self.dataset = self.dataset.shuffle(buffer_size=len(records))
# use sloppy tfrecord dataset
self.dataset = self.dataset.apply(
tf.contrib.data.parallel_interleave(
tf.data.TFRecordDataset,
sloppy=train,
cycle_length=min(num_workers, len(records))))
self.dataset = self.dataset.shuffle(buffer_size=100)
else:
self.dataset = tf.data.TFRecordDataset(records)
self.dataset = self.dataset.repeat()
# Instantiate dataloader (do not drop remainder for eval)
loader_args = {'batch_size': batch_size,
'num_parallel_batches': num_workers,
'drop_remainder': train}
self.dataloader = self.dataset.apply(
tf.contrib.data.map_and_batch(
self.record_converter, **loader_args))
self.threaded_dl = threaded_dl
self.num_workers = num_workers
def __iter__(self):
if self.threaded_dl:
data_iter = iter(MultiprocessLoader(self.dataloader, self.num_workers))
for item in data_iter:
yield item
else:
data_iter = iter(self.dataloader)
for item in data_iter:
yield convert_tf_example_to_torch_tensors(item)
class Record2Example(object):
def __init__(self, feature_map):
self.feature_map = feature_map
def __call__(self, record):
"""Decodes a BERT TF record to a TF example."""
example = tf.parse_single_example(record, self.feature_map)
for k, v in list(example.items()):
if v.dtype == tf.int64:
example[k] = tf.to_int32(v)
return example
def convert_tf_example_to_torch_tensors(example):
item = {k: (v.numpy()) for k, v in example.items()}
mask = np.zeros_like(item['input_ids'])
mask_labels = np.ones_like(item['input_ids']) * -1
for b, row in enumerate(item['masked_lm_positions'].astype(int)):
for i, idx in enumerate(row):
if item['masked_lm_weights'][b, i] != 0:
mask[b, idx] = 1
mask_labels[b, idx] = item['masked_lm_ids'][b, i]
output = {'text': item['input_ids'], 'types': item['segment_ids'], 'is_random': item['next_sentence_labels'],
'pad_mask': 1 - item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
return {k: torch.from_numpy(v) for k, v in output.items()}
class MultiprocessLoader(object):
def __init__(self, dataloader, num_workers=2):
self.dl = dataloader
self.queue_size = 2 * num_workers
def __iter__(self):
output_queue = queue.Queue(self.queue_size)
output_thread = threading.Thread(target=_multiproc_iter,
args=(self.dl, output_queue))
output_thread.daemon = True
output_thread.start()
while output_thread.is_alive():
yield output_queue.get(block=True)
else:
print(RuntimeError('TF record data loader thread exited unexpectedly'))
def _multiproc_iter(dl, output_queue):
data_iter = iter(dl)
for item in data_iter:
tensors = convert_tf_example_to_torch_tensors(item)
output_queue.put(tensors, block=True)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for using and training tokenizers (char, wordpiece, sentencepiece)"""
from collections import namedtuple
import random
import os
import csv
import torch
import nltk
from nltk import tokenize as nltk_tokenize
import sentencepiece as spm
from .wordpiece import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_gpt2 import GPT2Tokenizer
import regex as re
def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type='bpe',
pad_token=0, character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs):
"""
Helper function to instantiate a tokenizer given common combinations of options.
"""
tokenizer_class = tokenizer_type
if isinstance(tokenizer_class, str):
tokenizer_class = eval(tokenizer_class)
if tokenizer_class is BertWordPieceTokenizer:
return BertWordPieceTokenizer(model_type, **kwargs)
elif tokenizer_class is GPT2BPETokenizer:
return GPT2BPETokenizer(**kwargs)
text_tokenizer = tokenizer_class(corpus=corpus, vocab_size=vocab_size, model_path=model_path, model_type=model_type,
pad_token=pad_token, character_coverage=character_coverage)
return Tokenizer(text_tokenizer, command_tokens, type_tokens)
class Tokenization(object):
"""
Tokenization object to hold tokenization, (processed text),and original
text. Can hold tokenization as Ids or tokens.
It also holds command tokens (pad, unk, etc.) for the tokenization.
This allows functions to pad/operate on tokenizations without having
access to the full tokenizer, just the tokenization.
Several standard array operations are implemented (insert, append, extend).
"""
def __init__(self, tokenization, text=None, original_text=None,
command_tokens=None, asIds=True):
self.tokenization = tokenization
self.text = text
if self.text is None:
self.text = self.tokenization
self.original_text = original_text
if self.original_text is None:
self.original_text = self.text
self.command_tokens = command_tokens
self.asIds = asIds
self.parse_command_tokens()
def set_command_tokens(self, command_tokens):
self.command_tokens = command_tokens
return self.parse_command_tokens()
def parse_command_tokens(self):
if self.command_tokens is None:
return
for command_token in self.command_tokens:
if self.asIds:
setattr(self, command_token.name, command_token.Id)
else:
setattr(self, command_token.name, command_token.token)
def __getitem__(self, index):
return self.tokenization[index]
def __len__(self):
return len(self.tokenization)
def insert(self, idx, other):
if isinstance(other, (CommandToken, TypeToken)):
self.tokenization.insert(idx, other.Id)
if idx == 0:
self.text = other.token + self.text
self.original_text = other.token + self.original_text
elif idx == len(self.tokenization) - 1:
self.text += other.token
self.original_text += other.token
elif isinstance(other, Tokenization):
self.tokenization = self.tokenization[:idx] + \
other.tokenization + self.tokenization[idx:]
else:
self.tokenization = self.tokenization[:idx] + \
other.tokenization + self.tokenization[idx:]
def append(self, other):
if isinstance(other, (CommandToken, TypeToken)):
self.tokenization.append(other.Id)
self.text += other.token
self.original_text += other.token
elif isinstance(other, Tokenization):
self.tokenization.extend(other.tokenization)
self.text += other.text
self.original_text += other.original_text
else:
self.tokenization.append(other)
return self
def extend(self, other):
if isinstance(other, (CommandToken, TypeToken)):
self.tokenization.append(other.Id)
self.text += other.token
self.original_text += other.token
elif isinstance(other, list) and isinstance(other[0], (CommandToken, TypeToken)):
self.tokenization.extend([o.Id for o in other])
self.text += [o.token for o in other]
self.original_text += [o.token for o in other]
elif isinstance(other, Tokenization):
self.tokenization.extend(other.tokenization)
self.text += other.text
self.original_text += other.original_text
else:
self.tokenization.extend(other)
return self
"""define some default command tokens for the tokenizer to use"""
token_format = "<{0}>"
COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id'))
def prep_command_tokens(tokenlist, token_format=token_format):
return [CommandToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
class CommandToken(object):
def __init__(self, name, token, Id):
self.name = name
self.token = token
self.Id = Id
def __str__(self):
return str(COMMAND_TUPLE(self.name, self.token, self.Id))
DEFAULT_COMMAND_TOKENS = [
('pad', 0),
('eos', 1),
('bos', 2),
('unk', 3),
('sep', 4),
('L2R', 5),
('ENC', 6),
('MASK', 7),
]
DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS)
"""define some default type tokens for bert training"""
TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id'))
def prep_type_tokens(tokenlist, token_format=token_format):
return [TypeToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
class TypeToken(object):
def __init__(self, name, token, Id):
self.name = name
self.token = token
self.Id = Id
def __str__(self):
return str(TYPE_TUPLE(self.name, self.token, self.Id))
DEFAULT_TYPE_TOKENS = [
('function', 0),
('command', 1),
('str0', 2),
('str1', 3),
('str2', 4),
('embedding0', 5),
('embedding1', 6),
('embedding2', 7),
('arg0', 8),
('arg1', 9),
('arg2', 10),
]
DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS)
class Tokenizer(object):
"""
Tokenizer object that handles text tokenization, command tokens, and type tokens.
Command tokens and text tokens are stored together in one mapping of size
`len(text_tokenizer)+len(command_tokens)`. Command tokens are stored as first
`len(command_tokens)` tokens. Token idx is stored at `idx+len(command_tokens)`.
Token types are stored in a separate mapping of size `len(type_tokens)`.
"""
def __init__(self, text_tokenizer, command_tokens=None, type_tokens=None):
# set text tokenizer
self.text_tokenizer = text_tokenizer
if not hasattr(self, 'num_text_tokens'):
self.num_text_tokens = len(self.text_tokenizer)
# set command tokens
if command_tokens is None:
command_tokens = DEFAULT_COMMAND_TOKENS
self._command_tokens = command_tokens
self.command_name_map = {tok.name: tok for tok in self._command_tokens}
self.command_token_map = {tok.token: tok for tok in self._command_tokens}
self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
if not hasattr(self, 'num_command_tokens'):
self.num_command_tokens = len(self._command_tokens)
if not hasattr(self, 'num_tokens'):
self.num_tokens = self.num_command_tokens + self.num_text_tokens
# set type tokens
if type_tokens is None:
type_tokens = DEFAULT_TYPE_TOKENS
self.type_tokens = type_tokens
self.type_name_map = {tok.name: tok for tok in self.type_tokens}
self.type_token_map = {tok.token: tok for tok in self.type_tokens}
self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
if not hasattr(self, 'num_type_tokens'):
self.num_type_tokens = len(self.type_tokens)
# parse tokens and vocabs from tokenizer
self._tokens = list(self.command_token_map.keys()) + list(self.text_tokenizer.tokens)
self._vocab = {t: Id for Id, t in self.command_id_map.items()}
self._vocab.update({t: Id + self.num_command_tokens for t,
Id in self.text_tokenizer.vocab.items()})
self._text_tokens = list(self.text_tokenizer.tokens)
self._text_token_vocab = {
t: Id + self.num_command_tokens for t,
Id in self.text_tokenizer.vocab.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
def __call__(self, text, process_fn=None):
"""run preprocessing and encode text as Ids"""
return self.EncodeAsIds(text, process_fn=process_fn)
def __len__(self):
"""total number of tokens"""
return self.num_tokens
def get_command(self, name):
"""get command token corresponding to `name`"""
return self.command_name_map[name]
def get_type(self, name):
"""get type token corresponding to `name`"""
return self.type_name_map[name]
@property
def tokens(self):
"""list (or iterable) of all tokens for tokenizer"""
return self._tokens
@property
def vocab(self):
"""dictionary mapping tokens to ids for tokenizer"""
return self._vocab
@property
def token_types(self):
"""list (or iterable) of all token types for tokenizer"""
return self._token_types
@property
def token_type_vocab(self):
"""dictionary mapping token types to ids for tokenizer"""
return self._token_type_vocab
@property
def command_tokens(self):
"""list (or iterable) of all command tokens for tokenizer"""
return self._command_token_tokens
@property
def command_token_vocab(self):
"""dictionary mapping command tokens to ids for tokenizer"""
return self._command_token_vocab
@property
def text_tokens(self):
"""list (or iterable) of text tokens for text tokenizer"""
return self._text_tokens
@property
def text_token_vocab(self):
"""dictionary mapping text tokens to ids for text tokenizer"""
return self._text_token_vocab
def EncodeAsIds(self, text, process_fn=None):
"""
encode text using text tokenizer and shift Id values for command tokens
"""
tokenization = self.text_tokenizer.EncodeAsIds(text, process_fn=process_fn)
tokenization.tokenization = [t + self.num_command_tokens for t in tokenization.tokenization]
tokenization.set_command_tokens(self._command_tokens)
return tokenization
def EncodeAsTokens(self, text, process_fn=None):
"""
encode text as tokens using text tokenizer
"""
tokenization = self.text_tokenizer.EncodeAsTokens(text, process_fn=process_fn)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
def IdToToken(self, Id, type_token=False):
"""convert Id to token accounting for command and type tokens"""
if isinstance(Id, (TypeToken, CommandToken)):
return Id.token
if type_token:
return self.type_id_map[Id].token
if Id < self.num_command_tokens:
return self.command_id_map[Id].token
return self.text_tokenizer.IdToToken(Id - self.num_command_tokens)
def TokenToId(self, token, type_token=False):
"""convert token to Id accounting for command and type tokens"""
if isinstance(token, (TypeToken, CommandToken)):
return token.Id
if type_token:
return self.type_token_map[token].Id
if token in self.command_token_map:
return self.command_token_map[token].Id
return self.text_tokenizer.TokenToId(token) + self.num_command_tokens
def DecodeIds(self, Ids, type_token=False):
"""
convert Ids to tokens accounting for command and type tokens, tokens
are joined and returned as a string.
"""
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken)
else self.type_id_map[Id].token for Id in Ids)
rtn_strs = []
current_str = []
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
for Id in Ids:
if isinstance(Id, CommandToken):
rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
current_str = []
rtn_strs.append(t.token)
elif Id < self.num_command_tokens:
rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
current_str = []
rtn_strs.append(self.command_id_map[Id].token)
else:
current_str.append(Id - self.num_command_tokens)
if current_str != []:
rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
return ' '.join(rtn_strs)
def DecodeTokens(self, Tokens, type_token=False):
"""
convert tokens to a string accounting for command and type tokens.
"""
if type_token:
return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
rtn_strs = []
current_str = []
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
for t in Tokens:
if isinstance(t, CommandToken):
rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
current_str = []
rtn_strs.append(t.token)
elif t in self.command_token_map:
rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
current_str = []
rtn_strs.append(t)
else:
current_str.append(t)
if current_str != []:
rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
return ' '.join(rtn_strs)
class TextTokenizer(object):
"""
Interface for text tokenizer
"""
def __init__(self):
if not hasattr(self, 'num_text_tokens'):
self.num_text_tokens = 0
if not hasattr(self, 'num_tokens'):
self.num_tokens = self.num_text_tokens
def __call__(self, text, process_fn=None):
return self.EncodeAsIds(text, process_fn)
def __len__(self):
return self.num_text_tokens
@property
def tokens(self):
"""list (or iterable) of text tokens for text tokenizer"""
raise NotImplementedError('TextTokenizer tokens property not implemented')
@property
def vocab(self):
"""dictionary mapping tokens to ids"""
raise NotImplementedError('TextTokenizer vocab property not implemented')
@staticmethod
def exists(model_path):
"""check if the filepath for a text tokenizer exists"""
raise NotImplementedError('TextTokenizer exists method not implemented')
def Train(self, corpus):
"""train a tokenizer on a data corpus and save model for future use"""
raise NotImplementedError('TextTokenizer Train not implemented')
def EncodeAsIds(self, text, process_fn=None):
"""
Preprocess text and encode as ids. Return a tokenization object with
original text, processed text, and id tokenization.
"""
raise NotImplementedError('TextTokenizer EncodeAsIds not implemented')
def EncodeAsTokens(self, text, process_fn=None):
"""
Preprocess text and encode as tokens. Return a tokenization object with
original text, processed text, and token tokenization.
"""
raise NotImplementedError('TextTokenizer EncodeAsTokens not implemented')
def IdToToken(self, Id):
"""Convert an Id to Token. Reverse lookup of self.vocab"""
raise NotImplementedError('TextTokenizer IdToToken not implemented')
def TokenToId(self, token):
"""Convert a Token to Id. Lookup of self.vocab"""
raise NotImplementedError('TextTokenizer TokenToId not implemented')
def DecodeIds(self, Ids):
"""Convert a list or tokenization object of Ids to a text string"""
raise NotImplementedError('TextTokenizer DecodeIds not implemented')
def DecodeTokens(self, Tokens):
"""Convert a list or tokenization object of tokens to a text string"""
raise NotImplementedError('TextTokenizer DecodeTokens not implemented')
class CharacterLevelTokenizer(TextTokenizer):
"""
Text tokenizer for ASCII-256 Character Level Tokenization.
"""
def __init__(self, **kwargs):
self.num_text_tokens = 256
super(CharacterLevelTokenizer, self).__init__()
self._tokens = [self.IdToToken(Id) for Id in range(self.num_text_tokens)]
self._vocab = {t: i for i, t in enumerate(self._tokens)}
def __len__(self):
return 256
@staticmethod
def exists(model_path):
return True
def Train(self, corpus):
pass
@property
def tokens(self):
return self._tokens
@property
def vocab(self):
return self._vocab
def EncodeAsIds(self, text, process_fn=None):
"""convert text to ascii 256 Ids"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
processed_text = str(processed_text)
tokens = [self.TokenToId(c) for c in processed_text]
return Tokenization(tokens, processed_text, text)
def EncodeAsTokens(self, text, process_fn=None):
"""convert text to ascii 256 characters"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
processed_text = str(processed_text)
tokens = [c for c in processed_text]
return Tokenization(tokens, processed_text, text, asIds=False)
def IdToToken(self, Id):
"""ascii index to character"""
return chr(Id)
def TokenToId(self, token):
"""ascii character to index"""
return ord(token)
def DecodeIds(self, Ids):
"""converts ascii ids to tokens before joining them into text"""
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
return ''.join([self.IdToToken(tok) for tok in Ids])
def DecodeTokens(self, Tokens):
"""just concatenates ascii tokens into text"""
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
return ''.join(Tokens)
MAX_SENTENCEPIECE_SENTENCES = 100000000
def get_corpus_freq(dataset, filepath, filetype='tsv'):
"""
Take corpus, split it into sentences, and extract word frequencies.
Write frequencies to `filepath` as a tsv. Only write the first
MAX_SENTENCEPIECE_SENTENCES most common words to the file.
"""
nltk.download('punkt', download_dir="./nltk")
if filetype == 'tsv':
delimiter = '\t'
else:
delimiter = ','
print("compute corpus frequency\n", flush=True)
total_sentence_count = 0
maxlen = 0
freqs = {}
for entry in dataset:
if isinstance(entry, dict):
entry = entry['text']
lines = entry.strip().split('\n')
for line in lines:
sentences = nltk_tokenize.sent_tokenize(line)
total_sentence_count += len(sentences)
for sentence in sentences:
maxlen = max(len(line), maxlen)
for word in sentence.split():
if word not in freqs:
freqs[word] = 0
freqs[word] += 1
print("length of freqs before truncating " + str(len(freqs)), flush=True)
print("file path for freq " + str(filepath), flush=True)
freqs_sorted = {}
counter = 0
for word, count in sorted(freqs.items(), key=lambda x: x[1], reverse=True):
if counter >= MAX_SENTENCEPIECE_SENTENCES:
break
counter += 1
freqs_sorted[word] = count
print("length of freqs after trancating " + str(len(freqs_sorted)), flush=True)
with open(filepath, 'w') as f:
writer = csv.writer(f, delimiter=delimiter)
for k, v in freqs_sorted.items():
writer.writerow([str(k), str(v)])
return total_sentence_count, maxlen
class SentencePieceTokenizer(TextTokenizer):
"""Trains and uses sentencepiece for text tokenization"""
def __init__(self, model_type='bpe', vocab_size=None, corpus=None,
model_path=None, character_coverage=1.0, **kwargs):
self.character_coverage = character_coverage
self.model_type = model_type.lower()
self.spm_model = model_path
self.num_text_tokens = vocab_size
make_train = not SentencePieceTokenizer.exists(self.spm_model)
if make_train:
assert corpus is not None and self.num_text_tokens is not None
self.Train(corpus, self.num_text_tokens)
self._tokens = []
self._vocab = {}
self.load_spm_model()
super(SentencePieceTokenizer, self).__init__()
def __len__(self):
return self.num_text_tokens
@property
def tokens(self):
return self._tokens
@property
def vocab(self):
return self._vocab
@staticmethod
def exists(model_path):
if model_path is None:
return False
# check if path exists
dne = not os.path.exists(model_path)
# check if path.model exists
if dne and not model_path.endswith('.model'):
dne = not os.path.exists(model_path + '.model')
return not dne
def load_spm_model(self):
"""load sentencepiece model and parse vocab"""
if not os.path.exists(self.spm_model) and not self.spm_model.endswith('.model'):
self.spm_model = self.spm_model + '.model'
self.sp = spm.SentencePieceProcessor()
self.sp.Load(self.spm_model)
self.vocab_size = self.num_text_tokens = len(self.sp)
self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)]
self._vocab = {t: i for i, t in enumerate(self._tokens)}
def Train(self, corpus, num_text_tokens):
"""train sentencepiece model on corpus using word frequencies"""
self.num_text_tokens = num_text_tokens
use_model_path = self.spm_model
random_hash = str(random.randint(0, 2147483647))
if use_model_path is None:
use_model_path = random_hash
if use_model_path.endswith('.model'):
use_model_path = use_model_path[:use_model_path.rfind('.model')]
input_path = use_model_path + '.tsv.' + random_hash
line_count, maxlenline = get_corpus_freq(corpus, input_path)
line_count = min(line_count, MAX_SENTENCEPIECE_SENTENCES)
print('line count used as input_sentence_size ', line_count, flush=True)
print('training sentencepiece model', flush=True)
train_string = '--input={file_path} --model_prefix={model_prefix} --vocab_size={vocab_size}' \
+ ' --model_type={model_type} --character_coverage={character_coverage} ' \
+ '--input_sentence_size={input_sentence_size} ' \
+ '--input_format=tsv'
train_string = train_string.format(file_path=input_path, model_prefix=use_model_path, vocab_size=num_text_tokens,
model_type=self.model_type, character_coverage=self.character_coverage,
input_sentence_size=int(line_count)) # , #)#,
print("calling spm.SentencePieceTrainer.Train(%s)" % (train_string), flush=True)
spm.SentencePieceTrainer.Train(train_string)
os.remove(input_path)
self.spm_model = use_model_path + '.model'
print('sentencepiece model written to ' + self.spm_model, flush=True)
def EncodeAsIds(self, text, process_fn=None):
"""convert text to sentencepiece Ids"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = self.sp.EncodeAsIds(processed_text)
return Tokenization(tokens, processed_text, text)
def EncodeAsTokens(self, text, process_fn=None):
"""convert text to sentencepiece tokens"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = self.sp.EncodeAsTokens(processed_text)
return Tokenization(tokens, processed_text, text, asIds=False)
def IdToToken(self, Id):
"""convert Id to sentencpiece token"""
return self.sp.IdToPiece(Id)
def TokenToId(self, token):
"""convert sentencpiece token to Id"""
return self.sp.PieceToId(token)
def DecodeIds(self, Ids):
"""converts ids to a text string"""
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
return self.sp.DecodeIds(Ids)
def DecodeTokens(self, Tokens):
"""converts sentencepiece tokens to a text string"""
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
return self.sp.DecodeTokens(Tokens)
class BertWordPieceTokenizer(Tokenizer):
"""
Loads a pretrained WordPiece tokenizer from `cache_dir` for tokenization
in BERT training. Default to bert-large-uncased tokenizer.
"""
def __init__(self, tokenizer_model_type=None, cache_dir=None, **kwargs):
# default to bert-large-uncased tokenizer
if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP:
tokenizer_model_type = 'bert-large-uncased'
if torch.distributed.get_rank() == 0:
print(
'loading BertWordPieceTokenizer (',
tokenizer_model_type,
') from cache_dir ',
cache_dir)
do_lower_case = not ('-cased' in tokenizer_model_type or 'chinese' in tokenizer_model_type)
self.text_tokenizer = BertTokenizer.from_pretrained(
tokenizer_model_type, do_lower_case=do_lower_case, cache_dir=cache_dir)
if torch.distributed.get_rank() == 0:
print('loaded', tokenizer_model_type)
# disable max len warnings by increasing max len
self.text_tokenizer.max_len = int(1e12)
# set command tokens from wordpiece tokenizer values
self.num_command_tokens = 5
self.num_tokens = len(self.text_tokenizer.vocab)
self.num_text_tokens = self.num_tokens - 5
self.num_type_tokens = 2
self._command_tokens = [
CommandToken('pad', '[PAD]', self.text_tokenizer.vocab['[PAD]']),
CommandToken('ENC', '[CLS]', self.text_tokenizer.vocab['[CLS]']),
CommandToken('MASK', '[MASK]', self.text_tokenizer.vocab['[MASK]']),
CommandToken('unk', '[UNK]', self.text_tokenizer.vocab['[UNK]']),
CommandToken('sep', '[SEP]', self.text_tokenizer.vocab['[SEP]']),
]
self.command_name_map = {tok.name: tok for tok in self._command_tokens}
self.command_token_map = {tok.token: tok for tok in self._command_tokens}
self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
# set type tokens
self.type_tokens = [
TypeToken('str0', '<str0>', 0),
TypeToken('str1', '<str1>', 1),
]
self.type_name_map = {tok.name: tok for tok in self.type_tokens}
self.type_token_map = {tok.token: tok for tok in self.type_tokens}
self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
# parse tokens and vocabs from tokenizer
self._tokens = list(self.text_tokenizer.vocab.keys())
self._vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
self._text_tokens = list(self._tokens)
self._text_token_vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
def EncodeAsIds(self, text, process_fn=None):
"""convert text to wordpiece Ids"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = self.text_tokenizer.tokenize(processed_text)
Ids = self.text_tokenizer.convert_tokens_to_ids(tokens)
return Tokenization(Ids, processed_text, text)
def EncodeAsTokens(self, text, process_fn=None):
"""convert wordpiece token to Id"""
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = self.text_tokenizer.tokenize(processed_text)
return Tokenization(tokens, processed_text, text, asIds=False)
def IdToToken(self, Id, type_token=False):
"""convert Id to sentencpiece token"""
if isinstance(Id, (TypeToken, CommandToken)):
return Id.token
if type_token:
return self.type_id_map[Id].token
return self.text_tokenizer.ids_to_tokens[Id]
def TokenToId(self, token, type_token=False):
"""convert sentencpiece token to Id"""
if isinstance(token, (TypeToken, CommandToken)):
return token.Id
if type_token:
return self.type_token_map[token].Id
return self.text_tokenizer.vocab[token]
def DecodeIds(self, Ids, type_token=False):
"""converts ids to wordpiece tokens and joins them as a text string"""
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken)
else self.type_id_map[Id].token for Id in Ids)
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
Tokens = []
for Id in Ids:
Tokens.append(self.text_tokenizer.ids_to_tokens[Id] if Id != -1 else '-1')
Tokens = self.text_tokenizer.convert_ids_to_tokens(Ids)
return ' '.join(Tokens)
def DecodeTokens(self, Tokens, type_token=False):
"""converts wordpiece tokens to a text string"""
if type_token:
return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
return ' '.join(Tokens)
class GPT2BPETokenizer(Tokenizer):
def __init__(self, cache_dir=None, **kwargs):
self.text_tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
cache_dir=cache_dir)
# disable max len warnings by increasing max len
self.text_tokenizer.max_len = int(1e12)
self.num_command_tokens = 2
self.num_tokens = len(self.text_tokenizer.encoder)
self.num_text_tokens = self.num_tokens - 1
self.num_type_tokens = 2
self._command_tokens = [
CommandToken('pad', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>']),
CommandToken('eos', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>']),
]
self.command_name_map = {tok.name: tok for tok in self._command_tokens}
self.command_token_map = {tok.token: tok for tok in self._command_tokens}
self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
self.type_tokens = [
TypeToken('str0', '<str0>', 0),
TypeToken('str1', '<str1>', 1),
]
self.type_name_map = {tok.name: tok for tok in self.type_tokens}
self.type_token_map = {tok.token: tok for tok in self.type_tokens}
self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
self._tokens = list(self.text_tokenizer.encoder.keys())
self._vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
self._text_tokens = list(self._tokens)
self._text_token_vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
def EncodeAsIds(self, text, process_fn=None):
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
Ids = self.text_tokenizer.encode(processed_text)
# return Tokenization(Ids, processed_text, text)
tokenization = Tokenization(Ids, processed_text, text)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
def EncodeAsTokens(self, text, process_fn=None):
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
tokens = []
for token in re.findall(self.text_tokenizer.pat, processed_text):
token = ''.join(self.text_tokenizer.bye_encoder[b] for b in token.encode('utf-8'))
tokens.extend(bpe_token for bpe_token in self.text_tokenizer.bpe(token).split(' '))
tokenization = Tokenization(tokens, processed_text, text, asIds=False)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
# return Tokenization(tokens, processed_text, text, asIds=False)
def IdToToken(self, Id, type_token=False):
if isinstance(Id, (TypeToken, CommandToken)):
return Id.token
if type_token:
return self.type_id_map[Id].token
return self.text_tokenizer.decoder[Id]
def TokenToId(self, token, type_token=False):
if isinstance(token, (TypeToken, CommandToken)):
return token.Id
if type_token:
return self.type_token_map[token].Id
return self.text_tokenizer.encoder[token]
def DecodeIds(self, Ids, type_token=False):
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken)
else self.type_id_map[Id].token for Id in Ids)
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
return self.text_tokenizer.decode(Ids)
def DecodeTokens(self, Tokens, type_token=False):
if type_token:
return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
if isinstance(Tokens, Tokenization):
Tokens = Tokens.tokenization
return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import json
import logging
import os
import regex as re
from io import open
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
special_tokens_file = None
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
logger.info("loading merges file {}".format(merges_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(
resolved_vocab_file,
resolved_merges_file,
special_tokens=special_tokens,
*inputs,
**kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace',
special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should haved added re.IGNORECASE so BPE merges can happen for
# capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i)
for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except BaseException:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def tokenize(self, text):
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
token = ''.join(self.byte_encoder[ord(b)] for b in token)
else:
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
""" Converts a sequence of tokens into ids using the vocab. """
ids = []
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.encoder.get(token, 0))
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(vocab_path):
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
index = len(self.encoder)
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(special_tokens_file))
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file, merge_file, special_tokens_file
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes. Provided as is from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py"""
from __future__ import absolute_import, division, print_function, unicode_literals
import collections
import logging
import os
import unicodedata
from io import open
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'bert-base-uncased': 512,
'bert-large-uncased': 512,
'bert-base-cased': 512,
'bert-large-cased': 512,
'bert-base-multilingual-uncased': 512,
'bert-base-multilingual-cased': 512,
'bert-base-chinese': 512,
}
VOCAB_NAME = 'vocab.txt'
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r", encoding="utf-8") as reader:
while True:
token = reader.readline()
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class BertTokenizer(object):
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
"""Constructs a BertTokenizer.
Args:
vocab_file: Path to a one-wordpiece-per-line vocabulary file
do_lower_case: Whether to lower case the input
Only has an effect when do_wordpiece_only=False
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
max_len: An artificial maximum length to truncate tokenized sequences to;
Effective maximum length is always the minimum of this
value (if specified) and the underlying BERT model's
sequence length.
never_split: List of tokens which will never be split during tokenization.
Only has an effect when do_wordpiece_only=False
"""
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12)
def tokenize(self, text):
if self.do_basic_tokenize:
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
def convert_tokens_to_ids(self, tokens):
"""Converts a sequence of tokens into ids using the vocab."""
ids = []
for token in tokens:
ids.append(self.vocab[token])
if len(ids) > self.max_len:
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids):
"""Converts a sequence of ids in wordpiece tokens using the vocab."""
tokens = []
for i in ids:
tokens.append(self.ids_to_tokens[i])
return tokens
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
vocab_file = pretrained_model_name_or_path
if os.path.isdir(vocab_file):
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
return tokenizer
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self,
do_lower_case=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
self.never_split = never_split
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case and token not in self.never_split:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
if text in self.never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .fp16util import (
BN_convert_float,
network_to_half,
prep_param_lists,
model_grads_to_master_grads,
master_params_to_model_params,
tofp16,
to_python_float,
clip_grad_norm,
convert_module,
convert_network,
FP16Model,
)
from .fp16 import *
from .loss_scaler import *
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stable version of apex FP16 Optimizer"""
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .loss_scaler import DynamicLossScaler, LossScaler
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron.module import MegatronModule
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_fp16(val):
"""Convert fp32 `val` to fp16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, FLOAT_TYPES):
val = val.half()
return val
return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
"""Convert fp16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, HALF_TYPES):
val = val.float()
return val
return conversion_helper(val, float_conversion)
class FP16_Module(MegatronModule):
def __init__(self, module):
super(FP16_Module, self).__init__()
self.add_module('module', module.half())
def forward(self, *inputs, **kwargs):
return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars)
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
# TODO: Update overflow check + downscale to use Carl's fused kernel.
class FP16_Optimizer(object):
"""
:class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
and manage static or dynamic loss scaling and master weights in a manner transparent to the user.
For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance,
and changing the call to ``backward``.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# Name the FP16_Optimizer instance to replace the existing optimizer
# (recommended but not required):
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
# loss.backward() becomes:
optimizer.backward(loss)
...
Example with dynamic loss scaling::
...
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
# optional arg to control dynamic loss scaling behavior
# dynamic_loss_args={'scale_window' : 500})
# Usually, dynamic_loss_args is not necessary.
Args:
init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`.
static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate.
dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option.
dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used.
verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling.
``init_optimizer`` is expected to have been constructed in the ordinary way.
It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be
named to replace ``init_optimizer``, for two reasons:
First, it means that references to the same name
later in the file will not have to change.
Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to
modify ``init_optimizer``. If you do choose a unique name for the new
:class:`FP16_Optimizer` instance, you should only work with this new instance,
because the preexisting optimizer might no longer behave as expected.
``init_optimizer`` may be any Pytorch optimizer.
It may contain a mixture of fp16 and fp32 parameters organized into any number of
``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will
ingest these ``param_groups`` and remember them.
Calls to ::
loss.backward()
must be replaced with ::
optimizer.backward(loss)
because :class:`FP16_Optimizer` requires ownership of the backward pass to implement
loss scaling and copies to master gradients.
.. note::
Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients
are downscaled before being applied. This means that adjusting the loss scale, or using
dynamic loss scaling, should not require retuning the learning rate or any other
hyperparameters.
**Advanced options**
**Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure.
See docstring for :attr:`step`.
**Gradient clipping**: Use :attr:`clip_master_grads`.
**Multiple losses**: If your model accumulates gradients from multiple losses,
this can be made more efficient by supplying ``update_master_grads=False``
to :attr:`backward`. See docstring for :attr:`backward`.
**Manually adjusting loss scale**: The current loss scale can be retrieved or set via ::
print(optimizer.loss_scale)
optimizer.loss_scale = new_loss_scale
For static loss scaling, manually adjusting the loss scale over time is a reasonable
thing to do. During later epochs, gradients may become smaller, and a
higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss
scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting
the loss scale is not recommended.
**Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in
Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer`
should still work as intended.
"""
def __init__(self,
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=False):
if not torch.cuda.is_available:
raise SystemError("Cannot use fp16 without CUDA.")
self.verbose = verbose
self.optimizer = init_optimizer
# init_state_dict sets up an alternative way to cast per-param state tensors.
# Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
# init_state_dict = init_optimizer.state_dict()
self.fp16_groups = []
self.fp32_from_fp16_groups = []
self.fp32_from_fp32_groups = []
for i, param_group in enumerate(self.optimizer.param_groups):
self.maybe_print("FP16_Optimizer processing param group {}:".format(i))
fp16_params_this_group = []
fp32_params_this_group = []
fp32_from_fp16_params_this_group = []
for i, param in enumerate(param_group['params']):
if param.requires_grad:
if param.type() == 'torch.cuda.HalfTensor':
self.maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
.format(param.size()))
fp16_params_this_group.append(param)
master_param = param.detach().clone().float()
master_param.requires_grad = True
# Copythe model parallel flag.
master_param.model_parallel = param.model_parallel
param_group['params'][i] = master_param
fp32_from_fp16_params_this_group.append(master_param)
# Reset existing state dict key to the new master param.
# We still need to recast per-param state tensors, if any, to FP32.
if param in self.optimizer.state:
self.optimizer.state[master_param] = self.optimizer.state.pop(param)
elif param.type() == 'torch.cuda.FloatTensor':
self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
.format(param.size()))
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
raise TypeError("Wrapped parameters must be either "
"torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
"Received {}".format(param.type()))
self.fp16_groups.append(fp16_params_this_group)
self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
self.fp32_from_fp32_groups.append(fp32_params_this_group)
# Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
# alternative way to cast per-param state tensors:
# self.optimizer.load_state_dict(init_state_dict)
if dynamic_loss_scale:
self.dynamic_loss_scale = True
if dynamic_loss_args is not None:
self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
else:
self.loss_scaler = DynamicLossScaler()
else:
self.dynamic_loss_scale = False
self.loss_scaler = LossScaler(static_loss_scale)
self.overflow = False
self.first_closure_call_this_step = True
self.clip_grad_norm = clip_grad_norm
def maybe_print(self, msg):
if self.verbose:
print(msg)
def __getstate__(self):
raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")
def __setstate__(self, state):
raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")
def zero_grad(self, set_grads_to_None=False):
"""
Zero fp32 and fp16 parameter grads.
"""
# In principle, only the .grad attributes of the model params need to be zeroed,
# because gradients are copied into the FP32 master params. However, we zero
# all gradients owned by the optimizer, just to be safe:
for group in self.optimizer.param_groups:
for p in group['params']:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
# Zero fp16 gradients owned by the model:
for fp16_group in self.fp16_groups:
for param in fp16_group:
if set_grads_to_None:
param.grad = None
else:
if param.grad is not None:
param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
param.grad.zero_()
def _check_overflow(self):
params = []
for group in self.fp16_groups:
for param in group:
params.append(param)
for group in self.fp32_from_fp32_groups:
for param in group:
params.append(param)
self.overflow = self.loss_scaler.has_overflow(params)
def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
def _master_params_to_model_params(self):
for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
master_params_to_model_params(fp16_group, fp32_from_fp16_group)
def _model_params_to_master_params(self):
for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
master_params_to_model_params(fp32_from_fp16_group, fp16_group)
# To consider: Integrate distributed with this wrapper by registering a hook on each variable
# that does the overflow check, gradient copy + downscale, and fp32
# allreduce in a different stream.
def _model_grads_to_master_grads(self):
for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)
def _downscale_master(self):
if self.loss_scale != 1.0:
for group in self.optimizer.param_groups:
grads = [p.grad for p in group['params'] if p.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grads, grads],
1./self.loss_scale)
def clip_master_grads(self, max_norm, norm_type=2):
"""
Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.
Args:
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the current fp32 gradients (viewed as a single vector).
.. warning::
Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``).
"""
if not self.overflow:
fp32_params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
fp32_params.append(param)
return self.clip_grad_norm(fp32_params, max_norm, norm_type)
else:
return -1
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['loss_scaler'] = self.loss_scaler
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['overflow'] = self.overflow
state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups
return state_dict
def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.loss_scaler = state_dict['loss_scaler']
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.overflow = state_dict['overflow']
self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 2.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current_group, saved_group in zip(
self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
for current, saved in zip(current_group, saved_group):
current.data.copy_(saved.data)
def step(self, closure=None): # could add clip option.
"""
If no closure is supplied, :attr:`step` should be called after
``fp16_optimizer_obj.backward(loss)``.
:attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to
:class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params
originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run
another forward pass using their model.
If a closure is supplied, :attr:`step` may be called without a prior call to
:attr:`backward(loss)`.
This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.
However, the user should take care that any ``loss.backward()`` call within the closure
has been replaced by ``fp16_optimizer_obj.backward(loss)``.
Args:
closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss.
Example with closure::
# optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
# existing pytorch optimizer.
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# loss.backward() becomes:
optimizer.backward(loss)
return loss
optimizer.step(closure)
.. warning::
Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.
.. _`ordinary Pytorch optimizer use`:
http://pytorch.org/docs/master/optim.html#optimizer-step-closure
"""
scale = self.loss_scaler.loss_scale
self._update_scale(self.overflow)
if self.overflow:
self.maybe_print("OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}"
.format(scale, self.loss_scale))
return
if closure is not None:
retval = self._step_with_closure(closure)
else:
retval = self.optimizer.step()
self._master_params_to_model_params()
return retval
def _step_with_closure(self, closure):
def wrapped_closure():
# helpful for debugging
# print("Calling wrapped_closure, first_closure_call_this_step = {}"
# .format(self.first_closure_call_this_step))
if self.first_closure_call_this_step:
# We expect that the fp16 params are initially fresh on entering self.step(),
# so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
# is called within self.optimizer.step().
self.first_closure_call_this_step = False
else:
# If self.optimizer.step() internally calls wrapped_closure more than once,
# it may update the fp32 params after each call. However, self.optimizer
# doesn't know about the fp16 params at all. If the fp32 params get updated,
# we can't rely on self.optimizer to refresh the fp16 params. We need
# to handle that manually:
self._master_params_to_model_params()
# Our API expects the user to give us ownership of the backward() call by
# replacing all calls to loss.backward() with optimizer.backward(loss).
# This requirement holds whether or not the call to backward() is made within a closure.
# If the user is properly calling optimizer.backward(loss) within "closure,"
# calling closure() here will give the fp32 master params fresh gradients
# for the optimizer to play with, so all wrapped_closure needs to do is call
# closure() and return the loss.
temp_loss = closure()
while(self.overflow):
scale = self.loss_scaler.loss_scale
self._update_scale(self.overflow)
self.maybe_print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(scale, self.loss_scale))
temp_loss = closure()
return temp_loss
retval = self.optimizer.step(wrapped_closure)
self.first_closure_call_this_step = True
return retval
def backward(self, loss, update_master_grads=True, retain_graph=False):
"""
:attr:`backward` performs the following conceptual steps:
1. fp32_loss = loss.float() (see first Note below)
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined).
4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32.
5. Finally, master grads are divided by loss_scale.
In this way, after :attr:`backward`, the master params have fresh gradients,
and :attr:`step` may be called.
.. note::
:attr:`backward` internally converts the loss to fp32 before applying the loss scale.
This provides some additional safety against overflow if the user has supplied an
fp16 loss value.
However, for maximum overflow safety, the user should
compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
:attr:`backward`.
.. warning::
The gradients found in a model's leaves after the call to
:attr:`backward` should not be regarded as valid in general,
because it's possible
they have been scaled (and in the case of dynamic loss scaling,
the scale factor may change over time).
If the user wants to inspect gradients after a call to :attr:`backward`,
only the master gradients should be regarded as valid. These can be retrieved via
:attr:`inspect_master_grad_data()`.
Args:
loss: The loss output by the user's model. loss may be either float or half (but see first Note above).
update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`.
retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below).
Example::
# Ordinary operation:
optimizer.backward(loss)
# Naive operation with multiple losses (technically valid, but less efficient):
# fp32 grads will be correct after the second call, but
# the first call incurs an unnecessary fp16->fp32 grad copy.
optimizer.backward(loss1)
optimizer.backward(loss2)
# More efficient way to handle multiple losses:
# The fp16->fp32 grad copy is delayed until fp16 grads from all
# losses have been accumulated.
optimizer.backward(loss1, update_master_grads=False)
optimizer.backward(loss2, update_master_grads=False)
optimizer.update_master_grads()
"""
# To consider: try multiple backward passes using retain_grad=True to find
# a loss scale that works. After you find a loss scale that works, do a final dummy
# backward pass with retain_graph=False to tear down the graph. Doing this would avoid
# discarding the iteration, but probably wouldn't improve overall efficiency.
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
if update_master_grads:
self.update_master_grads()
def update_master_grads(self):
"""
Copy the ``.grad`` attribute from stored references to fp16 parameters to
the ``.grad`` attribute of the fp32 master parameters that are directly
updated by the optimizer. :attr:`update_master_grads` only needs to be called if
``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
"""
if self.dynamic_loss_scale:
self._check_overflow()
if self.overflow:
return
self._model_grads_to_master_grads()
self._downscale_master()
def inspect_master_grad_data(self):
"""
When running with :class:`FP16_Optimizer`,
``.grad`` attributes of a model's fp16 leaves should not be
regarded as truthful, because they might be scaled.
After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,
the fp32 master params' ``.grad``
attributes will contain valid gradients properly divided by the loss scale. However,
because :class:`FP16_Optimizer` flattens some parameters, accessing them may be
nonintuitive. :attr:`inspect_master_grad_data`
allows those gradients to be viewed with shapes corresponding to their associated model leaves.
Returns:
List of lists (one list for each parameter group). The list for each parameter group
is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.
"""
if self.overflow:
print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. "
"Gradients are currently invalid (may be inf, nan, or stale). Returning None.")
return None
else:
# The optimizer owns only references to master params.
master_grads_data = []
for param_group in self.optimizer.param_groups:
master_grads_this_group = []
for param in param_group['params']:
if param.grad is not None:
master_grads_this_group.append(param.grad.data)
else:
master_grads_this_group.append(None)
master_grads_data.append(master_grads_this_group)
return master_grads_data
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment