Commit 230156c4 authored by yangzhong's avatar yangzhong
Browse files

bert-large training

parents
Pipeline #3006 failed with stages
in 0 seconds
# Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#1 DGX1 phase1
bert--DGX1:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "1"
BATCHSIZE: "8192"
LR: "6e-3"
GRADIENT_STEPS: "512"
PHASE: "1"
#4 DGX1 phase1
bert--DGX1_4x8x16x128:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "4"
BATCHSIZE: "2048"
LR: "6e-3"
GRADIENT_STEPS: "128"
PHASE: "1"
#16 DGX1 phase1
bert--DGX1_16x8x16x32:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "16"
BATCHSIZE: "512"
LR: "6e-3"
GRADIENT_STEPS: "32"
PHASE: "1"
#1 DGX2 phase1
bert--DGX2:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "1"
BATCHSIZE: "4096"
LR: "6e-3"
GRADIENT_STEPS: "64"
PHASE: "1"
#4 DGX2 phase1
bert--DGX2_4x16x64x16:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "4"
BATCHSIZE: "1024"
LR: "6e-3"
GRADIENT_STEPS: "16"
PHASE: "1"
#16 DGX2 phase1
bert--DGX2_16x16x64x4:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "16"
BATCHSIZE: "256"
LR: "6e-3"
GRADIENT_STEPS: "4"
PHASE: "1"
#64 DGX2 phase1
bert--DGX2_64x16x64:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "64"
BATCHSIZE: "64"
LR: "6e-3"
GRADIENT_STEPS: "1"
PHASE: "1"
#1 DGX1 phase2
bert--DGX1_1x8x4x1024:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "1"
BATCHSIZE: "4096"
LR: "4e-3"
GRADIENT_STEPS: "1024"
PHASE: "2"
#4 DGX1 phase2
bert--DGX1_4x8x4x256:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "4"
BATCHSIZE: "1024"
LR: "4e-3"
GRADIENT_STEPS: "256"
PHASE: "2"
#16 DGX1 phase2
bert--DGX1_16x8x4x64:
<<: *BERT_ON_CLUSTER
<<: *DGX1
variables:
<<: *DGX1_VARS
NNODES: "16"
BATCHSIZE: "256"
LR: "4e-3"
GRADIENT_STEPS: "64"
PHASE: "2"
#1 DGX2 phase2
bert--DGX2_1x16x8x256:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "1"
BATCHSIZE: "2048"
LR: "4e-3"
GRADIENT_STEPS: "256"
PHASE: "2"
#4 DGX2 phase2
bert--DGX2_4x16x8x64:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "4"
BATCHSIZE: "512"
LR: "4e-3"
GRADIENT_STEPS: "64"
PHASE: "2"
#16 DGX2 phase2
bert--DGX2_16x16x8x16:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "16"
BATCHSIZE: "128"
LR: "4e-3"
GRADIENT_STEPS: "16"
PHASE: "2"
#64 DGX2 phase2
bert--DGX2_64x16x8x4:
<<: *BERT_ON_CLUSTER
<<: *DGX2
variables:
<<: *DGX2_VARS
NNODES: "64"
BATCHSIZE: "32"
LR: "4e-3"
GRADIENT_STEPS: "4"
PHASE: "2"
# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import logging
import os
import random
from io import open
import h5py
import numpy as np
from tqdm import tqdm, trange
from tokenization import BertTokenizer
import tokenization as tokenization
import random
import collections
class TrainingInstance(object):
"""A single training instance (sentence pair)."""
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
is_random_next):
self.tokens = tokens
self.segment_ids = segment_ids
self.is_random_next = is_random_next
self.masked_lm_positions = masked_lm_positions
self.masked_lm_labels = masked_lm_labels
def __str__(self):
s = ""
s += "tokens: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.tokens]))
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
s += "is_random_next: %s\n" % self.is_random_next
s += "masked_lm_positions: %s\n" % (" ".join(
[str(x) for x in self.masked_lm_positions]))
s += "masked_lm_labels: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
s += "\n"
return s
def __repr__(self):
return self.__str__()
def write_instance_to_example_file(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_file):
"""Create TF example files from `TrainingInstance`s."""
total_written = 0
features = collections.OrderedDict()
num_instances = len(instances)
features["input_ids"] = np.zeros([num_instances, max_seq_length], dtype="int32")
features["input_mask"] = np.zeros([num_instances, max_seq_length], dtype="int32")
features["segment_ids"] = np.zeros([num_instances, max_seq_length], dtype="int32")
features["masked_lm_positions"] = np.zeros([num_instances, max_predictions_per_seq], dtype="int32")
features["masked_lm_ids"] = np.zeros([num_instances, max_predictions_per_seq], dtype="int32")
features["next_sentence_labels"] = np.zeros(num_instances, dtype="int32")
for inst_index, instance in enumerate(tqdm(instances)):
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
input_mask = [1] * len(input_ids)
segment_ids = list(instance.segment_ids)
assert len(input_ids) <= max_seq_length
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
masked_lm_positions = list(instance.masked_lm_positions)
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
masked_lm_weights = [1.0] * len(masked_lm_ids)
while len(masked_lm_positions) < max_predictions_per_seq:
masked_lm_positions.append(0)
masked_lm_ids.append(0)
masked_lm_weights.append(0.0)
next_sentence_label = 1 if instance.is_random_next else 0
features["input_ids"][inst_index] = input_ids
features["input_mask"][inst_index] = input_mask
features["segment_ids"][inst_index] = segment_ids
features["masked_lm_positions"][inst_index] = masked_lm_positions
features["masked_lm_ids"][inst_index] = masked_lm_ids
features["next_sentence_labels"][inst_index] = next_sentence_label
total_written += 1
# if inst_index < 20:
# tf.logging.info("*** Example ***")
# tf.logging.info("tokens: %s" % " ".join(
# [tokenization.printable_text(x) for x in instance.tokens]))
# for feature_name in features.keys():
# feature = features[feature_name]
# values = []
# if feature.int64_list.value:
# values = feature.int64_list.value
# elif feature.float_list.value:
# values = feature.float_list.value
# tf.logging.info(
# "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
print("saving data")
f= h5py.File(output_file, 'w')
f.create_dataset("input_ids", data=features["input_ids"], dtype='i4', compression='gzip')
f.create_dataset("input_mask", data=features["input_mask"], dtype='i1', compression='gzip')
f.create_dataset("segment_ids", data=features["segment_ids"], dtype='i1', compression='gzip')
f.create_dataset("masked_lm_positions", data=features["masked_lm_positions"], dtype='i4', compression='gzip')
f.create_dataset("masked_lm_ids", data=features["masked_lm_ids"], dtype='i4', compression='gzip')
f.create_dataset("next_sentence_labels", data=features["next_sentence_labels"], dtype='i1', compression='gzip')
f.flush()
f.close()
def create_training_instances(input_files, tokenizer, max_seq_length,
dupe_factor, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, rng):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
# Input file format:
# (1) One sentence per line. These should ideally be actual sentences, not
# entire paragraphs or arbitrary spans of text. (Because we use the
# sentence boundaries for the "next sentence prediction" task).
# (2) Blank lines between documents. Document boundaries are needed so
# that the "next sentence prediction" task doesn't span between documents.
for input_file in input_files:
print("creating instance from {}".format(input_file))
with open(input_file, "r") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
break
line = line.strip()
# Empty lines are used as document delimiters
if not line:
all_documents.append([])
tokens = tokenizer.tokenize(line)
if tokens:
all_documents[-1].append(tokens)
# Remove empty documents
all_documents = [x for x in all_documents if x]
rng.shuffle(all_documents)
vocab_words = list(tokenizer.vocab.keys())
instances = []
for _ in range(dupe_factor):
for document_index in range(len(all_documents)):
instances.extend(
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
rng.shuffle(instances)
return instances
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
# Account for [CLS], [SEP], [SEP]
max_num_tokens = max_seq_length - 3
# We *usually* want to fill up the entire sequence since we are padding
# to `max_seq_length` anyways, so short sequences are generally wasted
# computation. However, we *sometimes*
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
# sequences to minimize the mismatch between pre-training and fine-tuning.
# The `target_seq_length` is just a rough target however, whereas
# `max_seq_length` is a hard limit.
target_seq_length = max_num_tokens
if rng.random() < short_seq_prob:
target_seq_length = rng.randint(2, max_num_tokens)
# We DON'T just concatenate all of the tokens from a document into a long
# sequence and choose an arbitrary split point because this would make the
# next sentence prediction task too easy. Instead, we split the input into
# segments "A" and "B" based on the actual "sentences" provided by the user
# input.
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end` is how many segments from `current_chunk` go into the `A`
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = rng.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
# Random next
is_random_next = False
if len(current_chunk) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
# This should rarely go for more than one iteration for large
# corpora. However, just to be careful, we try to make sure that
# the random document is not the same as the document
# we're processing.
for _ in range(10):
random_document_index = rng.randint(0, len(all_documents) - 1)
if random_document_index != document_index:
break
#If picked random document is the same as the current document
if random_document_index == document_index:
is_random_next = False
random_document = all_documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# We didn't actually use these segments so we "put them back" so
# they don't go to waste.
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
# Actual next
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
is_random_next=is_random_next,
masked_lm_positions=masked_lm_positions,
masked_lm_labels=masked_lm_labels)
instances.append(instance)
current_chunk = []
current_length = 0
i += 1
return instances
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
cand_indexes.append(i)
rng.shuffle(cand_indexes)
output_tokens = list(tokens)
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
if index in covered_indexes:
continue
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_positions = []
masked_lm_labels = []
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels)
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
"""Truncates a pair of sequences to a maximum sequence length."""
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
# We want to sometimes truncate from the front and sometimes from the
# back to add more randomness and avoid biases.
if rng.random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--vocab_file",
default=None,
type=str,
required=True,
help="The vocabulary the BERT model will train on.")
parser.add_argument("--input_file",
default=None,
type=str,
required=True,
help="The input train corpus. can be directory with .txt files or a path to a single file")
parser.add_argument("--output_file",
default=None,
type=str,
required=True,
help="The output file where the model checkpoints will be written.")
## Other parameters
# str
parser.add_argument("--bert_model", default="bert-large-uncased", type=str, required=False,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
#int
parser.add_argument("--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument("--dupe_factor",
default=10,
type=int,
help="Number of times to duplicate the input data (with different masks).")
parser.add_argument("--max_predictions_per_seq",
default=20,
type=int,
help="Maximum sequence length.")
# floats
parser.add_argument("--masked_lm_prob",
default=0.15,
type=float,
help="Masked LM probability.")
parser.add_argument("--short_seq_prob",
default=0.1,
type=float,
help="Probability to create a sequence shorter than maximum sequence length")
parser.add_argument("--do_lower_case",
action='store_true',
default=True,
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument('--random_seed',
type=int,
default=12345,
help="random seed for initialization")
args = parser.parse_args()
tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512)
input_files = []
if os.path.isfile(args.input_file):
input_files.append(args.input_file)
elif os.path.isdir(args.input_file):
input_files = [os.path.join(args.input_file, f) for f in os.listdir(args.input_file) if (os.path.isfile(os.path.join(args.input_file, f)) and f.endswith('.txt') )]
else:
raise ValueError("{} is not a valid path".format(args.input_file))
rng = random.Random(args.random_seed)
instances = create_training_instances(
input_files, tokenizer, args.max_seq_length, args.dupe_factor,
args.short_seq_prob, args.masked_lm_prob, args.max_predictions_per_seq,
rng)
output_file = args.output_file
write_instance_to_example_file(instances, tokenizer, args.max_seq_length,
args.max_predictions_per_seq, output_file)
if __name__ == "__main__":
main()
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
class BooksDownloader:
def __init__(self, save_path):
self.save_path = save_path
pass
def download(self):
bookscorpus_download_command = 'python3 /workspace/bookcorpus/download_files.py --list /workspace/bookcorpus/url_list.jsonl --out'
bookscorpus_download_command += ' ' + self.save_path + '/bookscorpus'
bookscorpus_download_command += ' --trash-bad-count'
bookscorpus_download_process = subprocess.run(bookscorpus_download_command, shell=True, check=True)
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
class BookscorpusTextFormatting:
def __init__(self, books_path, output_filename, recursive = False):
self.books_path = books_path
self.recursive = recursive
self.output_filename = output_filename
# This puts one book per line
def merge(self):
with open(self.output_filename, mode='w', newline='\n') as ofile:
for filename in glob.glob(self.books_path + '/' + '*.txt', recursive=True):
with open(filename, mode='r', encoding='utf-8-sig', newline='\n') as file:
for line in file:
if line.strip() != '':
ofile.write(line.strip() + ' ')
ofile.write("\n\n")
\ No newline at end of file
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from GooglePretrainedWeightDownloader import GooglePretrainedWeightDownloader
from NVIDIAPretrainedWeightDownloader import NVIDIAPretrainedWeightDownloader
from WikiDownloader import WikiDownloader
from BooksDownloader import BooksDownloader
from GLUEDownloader import GLUEDownloader
from SquadDownloader import SquadDownloader
class Downloader:
def __init__(self, dataset_name, save_path):
self.dataset_name = dataset_name
self.save_path = save_path
def download(self):
if self.dataset_name == 'bookscorpus':
self.download_bookscorpus()
elif self.dataset_name == 'wikicorpus_en':
self.download_wikicorpus('en')
elif self.dataset_name == 'wikicorpus_zh':
self.download_wikicorpus('zh')
elif self.dataset_name == 'google_pretrained_weights':
self.download_google_pretrained_weights()
elif self.dataset_name == 'nvidia_pretrained_weights':
self.download_nvidia_pretrained_weights()
elif self.dataset_name in {'mrpc', 'sst-2'}:
self.download_glue(self.dataset_name)
elif self.dataset_name == 'squad':
self.download_squad()
elif self.dataset_name == 'all':
self.download_bookscorpus()
self.download_wikicorpus('en')
self.download_wikicorpus('zh')
self.download_google_pretrained_weights()
self.download_nvidia_pretrained_weights()
self.download_glue('mrpc')
self.download_glue('sst-2')
self.download_squad()
else:
print(self.dataset_name)
assert False, 'Unknown dataset_name provided to downloader'
def download_bookscorpus(self):
downloader = BooksDownloader(self.save_path)
downloader.download()
def download_wikicorpus(self, language):
downloader = WikiDownloader(language, self.save_path)
downloader.download()
def download_google_pretrained_weights(self):
downloader = GooglePretrainedWeightDownloader(self.save_path)
downloader.download()
def download_nvidia_pretrained_weights(self):
downloader = NVIDIAPretrainedWeightDownloader(self.save_path)
downloader.download()
def download_glue(self, task_name):
downloader = GLUEDownloader(self.save_path)
downloader.download(task_name)
def download_squad(self):
downloader = SquadDownloader(self.save_path)
downloader.download()
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import wget
from pathlib import Path
def mkdir(path):
Path(path).mkdir(parents=True, exist_ok=True)
class GLUEDownloader:
def __init__(self, save_path):
self.save_path = save_path + '/glue'
def download(self, task_name):
mkdir(self.save_path)
if task_name in {'mrpc', 'mnli'}:
task_name = task_name.upper()
elif task_name == 'cola':
task_name = 'CoLA'
else: # SST-2
assert task_name == 'sst-2'
task_name = 'SST'
wget.download(
'https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py',
out=self.save_path,
)
sys.path.append(self.save_path)
import download_glue_data
download_glue_data.main(
['--data_dir', self.save_path, '--tasks', task_name])
sys.path.pop()
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
import urllib.request
import zipfile
class GooglePretrainedWeightDownloader:
def __init__(self, save_path):
self.save_path = save_path + '/google_pretrained_weights'
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
# Download urls
self.model_urls = {
'bert_base_uncased': ('https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip', 'uncased_L-12_H-768_A-12.zip'),
'bert_large_uncased': ('https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip', 'uncased_L-24_H-1024_A-16.zip'),
'bert_base_cased': ('https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip', 'cased_L-12_H-768_A-12.zip'),
'bert_large_cased': ('https://storage.googleapis.com/bert_models/2018_10_18/cased_L-24_H-1024_A-16.zip', 'cased_L-24_H-1024_A-16.zip'),
'bert_base_multilingual_cased': ('https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip', 'multi_cased_L-12_H-768_A-12.zip'),
'bert_large_multilingual_uncased': ('https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip', 'multilingual_L-12_H-768_A-12.zip'),
'bert_base_chinese': ('https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip', 'chinese_L-12_H-768_A-12.zip')
}
# SHA256sum verification for file download integrity (and checking for changes from the download source over time)
self.bert_base_uncased_sha = {
'bert_config.json': '7b4e5f53efbd058c67cda0aacfafb340113ea1b5797d9ce6ee411704ba21fcbc',
'bert_model.ckpt.data-00000-of-00001': '58580dc5e0bf0ae0d2efd51d0e8272b2f808857f0a43a88aaf7549da6d7a8a84',
'bert_model.ckpt.index': '04c1323086e2f1c5b7c0759d8d3e484afbb0ab45f51793daab9f647113a0117b',
'bert_model.ckpt.meta': 'dd5682170a10c3ea0280c2e9b9a45fee894eb62da649bbdea37b38b0ded5f60e',
'vocab.txt': '07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3',
}
self.bert_large_uncased_sha = {
'bert_config.json': 'bfa42236d269e2aeb3a6d30412a33d15dbe8ea597e2b01dc9518c63cc6efafcb',
'bert_model.ckpt.data-00000-of-00001': 'bc6b3363e3be458c99ecf64b7f472d2b7c67534fd8f564c0556a678f90f4eea1',
'bert_model.ckpt.index': '68b52f2205ffc64dc627d1120cf399c1ef1cbc35ea5021d1afc889ffe2ce2093',
'bert_model.ckpt.meta': '6fcce8ff7628f229a885a593625e3d5ff9687542d5ef128d9beb1b0c05edc4a1',
'vocab.txt': '07eced375cec144d27c900241f3e339478dec958f92fddbc551f295c992038a3',
}
self.bert_base_cased_sha = {
'bert_config.json': 'f11dfb757bea16339a33e1bf327b0aade6e57fd9c29dc6b84f7ddb20682f48bc',
'bert_model.ckpt.data-00000-of-00001': '734d5a1b68bf98d4e9cb6b6692725d00842a1937af73902e51776905d8f760ea',
'bert_model.ckpt.index': '517d6ef5c41fc2ca1f595276d6fccf5521810d57f5a74e32616151557790f7b1',
'bert_model.ckpt.meta': '5f8a9771ff25dadd61582abb4e3a748215a10a6b55947cbb66d0f0ba1694be98',
'vocab.txt': 'eeaa9875b23b04b4c54ef759d03db9d1ba1554838f8fb26c5d96fa551df93d02',
}
self.bert_large_cased_sha = {
'bert_config.json': '7adb2125c8225da495656c982fd1c5f64ba8f20ad020838571a3f8a954c2df57',
'bert_model.ckpt.data-00000-of-00001': '6ff33640f40d472f7a16af0c17b1179ca9dcc0373155fb05335b6a4dd1657ef0',
'bert_model.ckpt.index': 'ef42a53f577fbe07381f4161b13c7cab4f4fc3b167cec6a9ae382c53d18049cf',
'bert_model.ckpt.meta': 'd2ddff3ed33b80091eac95171e94149736ea74eb645e575d942ec4a5e01a40a1',
'vocab.txt': 'eeaa9875b23b04b4c54ef759d03db9d1ba1554838f8fb26c5d96fa551df93d02',
}
self.bert_base_multilingual_cased_sha = {
'bert_config.json': 'e76c3964bc14a8bb37a5530cdc802699d2f4a6fddfab0611e153aa2528f234f0',
'bert_model.ckpt.data-00000-of-00001': '55b8a2df41f69c60c5180e50a7c31b7cdf6238909390c4ddf05fbc0d37aa1ac5',
'bert_model.ckpt.index': '7d8509c2a62b4e300feb55f8e5f1eef41638f4998dd4d887736f42d4f6a34b37',
'bert_model.ckpt.meta': '95e5f1997e8831f1c31e5cf530f1a2e99f121e9cd20887f2dce6fe9e3343e3fa',
'vocab.txt': 'fe0fda7c425b48c516fc8f160d594c8022a0808447475c1a7c6d6479763f310c',
}
self.bert_large_multilingual_uncased_sha = {
'bert_config.json': '49063bb061390211d2fdd108cada1ed86faa5f90b80c8f6fdddf406afa4c4624',
'bert_model.ckpt.data-00000-of-00001': '3cd83912ebeb0efe2abf35c9f1d5a515d8e80295e61c49b75c8853f756658429',
'bert_model.ckpt.index': '87c372c1a3b1dc7effaaa9103c80a81b3cbab04c7933ced224eec3b8ad2cc8e7',
'bert_model.ckpt.meta': '27f504f34f02acaa6b0f60d65195ec3e3f9505ac14601c6a32b421d0c8413a29',
'vocab.txt': '87b44292b452f6c05afa49b2e488e7eedf79ea4f4c39db6f2f4b37764228ef3f',
}
self.bert_base_chinese_sha = {
'bert_config.json': '7aaad0335058e2640bcb2c2e9a932b1cd9da200c46ea7b8957d54431f201c015',
'bert_model.ckpt.data-00000-of-00001': '756699356b78ad0ef1ca9ba6528297bcb3dd1aef5feadd31f4775d7c7fc989ba',
'bert_model.ckpt.index': '46315546e05ce62327b3e2cd1bed22836adcb2ff29735ec87721396edb21b82e',
'bert_model.ckpt.meta': 'c0f8d51e1ab986604bc2b25d6ec0af7fd21ff94cf67081996ec3f3bf5d823047',
'vocab.txt': '45bbac6b341c319adc98a532532882e91a9cefc0329aa57bac9ae761c27b291c',
}
# Relate SHA to urls for loop below
self.model_sha = {
'bert_base_uncased': self.bert_base_uncased_sha,
'bert_large_uncased': self.bert_large_uncased_sha,
'bert_base_cased': self.bert_base_cased_sha,
'bert_large_cased': self.bert_large_cased_sha,
'bert_base_multilingual_cased': self.bert_base_multilingual_cased_sha,
'bert_large_multilingual_uncased': self.bert_large_multilingual_uncased_sha,
'bert_base_chinese': self.bert_base_chinese_sha
}
# Helper to get sha256sum of a file
def sha256sum(self, filename):
h = hashlib.sha256()
b = bytearray(128*1024)
mv = memoryview(b)
with open(filename, 'rb', buffering=0) as f:
for n in iter(lambda : f.readinto(mv), 0):
h.update(mv[:n])
return h.hexdigest()
def download(self):
# Iterate over urls: download, unzip, verify sha256sum
found_mismatch_sha = False
for model in self.model_urls:
url = self.model_urls[model][0]
file = self.save_path + '/' + self.model_urls[model][1]
print('Downloading', url)
response = urllib.request.urlopen(url)
with open(file, 'wb') as handle:
handle.write(response.read())
print('Unzipping', file)
zip = zipfile.ZipFile(file, 'r')
zip.extractall(self.save_path)
zip.close()
sha_dict = self.model_sha[model]
for extracted_file in sha_dict:
sha = sha_dict[extracted_file]
if sha != self.sha256sum(file[:-4] + '/' + extracted_file):
found_mismatch_sha = True
print('SHA256sum does not match on file:', extracted_file, 'from download url:', url)
else:
print(file[:-4] + '/' + extracted_file, '\t', 'verified')
if not found_mismatch_sha:
print("All downloads pass sha256sum verification.")
def serialize(self):
pass
def deserialize(self):
pass
def listAvailableWeights(self):
print("Available Weight Datasets")
for item in self.model_urls:
print(item)
def listLocallyStoredWeights(self):
pass
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
class NVIDIAPretrainedWeightDownloader:
def __init__(self, save_path):
self.save_path = save_path + '/nvidia_pretrained_weights'
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
pass
def download(self):
assert False, 'NVIDIAPretrainedWeightDownloader not implemented yet.'
\ No newline at end of file
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bz2
import os
import urllib.request
import sys
class SquadDownloader:
def __init__(self, save_path):
self.save_path = save_path + '/squad'
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
if not os.path.exists(self.save_path + '/v1.1'):
os.makedirs(self.save_path + '/v1.1')
if not os.path.exists(self.save_path + '/v2.0'):
os.makedirs(self.save_path + '/v2.0')
self.download_urls = {
'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json' : 'v1.1/train-v1.1.json',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json' : 'v1.1/dev-v1.1.json',
'https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/' : 'v1.1/evaluate-v1.1.py',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json' : 'v2.0/train-v2.0.json',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json' : 'v2.0/dev-v2.0.json',
'https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/' : 'v2.0/evaluate-v2.0.py',
}
def download(self):
for item in self.download_urls:
url = item
file = self.download_urls[item]
print('Downloading:', url)
if os.path.isfile(self.save_path + '/' + file):
print('** Download file already exists, skipping download')
else:
response = urllib.request.urlopen(url)
with open(self.save_path + '/' + file, "wb") as handle:
handle.write(response.read())
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from itertools import islice
import multiprocessing
import statistics
class Sharding:
def __init__(self, input_files, output_name_prefix, n_training_shards, n_test_shards, fraction_test_set):
assert len(input_files) > 0, 'The input file list must contain at least one file.'
assert n_training_shards > 0, 'There must be at least one output shard.'
assert n_test_shards > 0, 'There must be at least one output shard.'
self.n_training_shards = n_training_shards
self.n_test_shards = n_test_shards
self.fraction_test_set = fraction_test_set
self.input_files = input_files
self.output_name_prefix = output_name_prefix
self.output_training_identifier = '_training'
self.output_test_identifier = '_test'
self.output_file_extension = '.txt'
self.articles = {} # key: integer identifier, value: list of articles
self.sentences = {} # key: integer identifier, value: list of sentences
self.output_training_files = {} # key: filename, value: list of articles to go into file
self.output_test_files = {} # key: filename, value: list of articles to go into file
self.init_output_files()
# Remember, the input files contain one article per line (the whitespace check is to skip extraneous blank lines)
def load_articles(self):
print('Start: Loading Articles')
global_article_count = 0
for input_file in self.input_files:
print('input file:', input_file)
with open(input_file, mode='r', newline='\n') as f:
for i, line in enumerate(f):
if line.strip():
self.articles[global_article_count] = line.rstrip()
global_article_count += 1
print('End: Loading Articles: There are', len(self.articles), 'articles.')
def segment_articles_into_sentences(self, segmenter):
print('Start: Sentence Segmentation')
if len(self.articles) is 0:
self.load_articles()
assert len(self.articles) is not 0, 'Please check that input files are present and contain data.'
# TODO: WIP: multiprocessing (create independent ranges and spawn processes)
use_multiprocessing = 'serial'
def chunks(data, size=len(self.articles)):
it = iter(data)
for i in range(0, len(data), size):
yield {k: data[k] for k in islice(it, size)}
if use_multiprocessing == 'manager':
manager = multiprocessing.Manager()
return_dict = manager.dict()
jobs = []
n_processes = 7 # in addition to the main process, total = n_proc+1
def work(articles, return_dict):
sentences = {}
for i, article in enumerate(articles):
sentences[i] = segmenter.segment_string(articles[article])
if i % 5000 == 0:
print('Segmenting article', i)
return_dict.update(sentences)
for item in chunks(self.articles, len(self.articles)):
p = multiprocessing.Process(target=work, args=(item, return_dict))
# Busy wait
while len(jobs) >= n_processes:
pass
jobs.append(p)
p.start()
for proc in jobs:
proc.join()
elif use_multiprocessing == 'queue':
work_queue = multiprocessing.Queue()
jobs = []
for item in chunks(self.articles, len(self.articles)):
pass
else: # serial option
for i, article in enumerate(self.articles):
self.sentences[i] = segmenter.segment_string(self.articles[article])
if i % 5000 == 0:
print('Segmenting article', i)
print('End: Sentence Segmentation')
def init_output_files(self):
print('Start: Init Output Files')
assert len(self.output_training_files) is 0, 'Internal storage self.output_files already contains data. This function is intended to be used by the constructor only.'
assert len(self.output_test_files) is 0, 'Internal storage self.output_files already contains data. This function is intended to be used by the constructor only.'
for i in range(self.n_training_shards):
name = self.output_name_prefix + self.output_training_identifier + '_' + str(i) + self.output_file_extension
self.output_training_files[name] = []
for i in range(self.n_test_shards):
name = self.output_name_prefix + self.output_test_identifier + '_' + str(i) + self.output_file_extension
self.output_test_files[name] = []
print('End: Init Output Files')
def get_sentences_per_shard(self, shard):
result = 0
for article_id in shard:
result += len(self.sentences[article_id])
return result
def distribute_articles_over_shards(self):
print('Start: Distribute Articles Over Shards')
assert len(self.articles) >= self.n_training_shards + self.n_test_shards, 'There are fewer articles than shards. Please add more data or reduce the number of shards requested.'
# Create dictionary with - key: sentence count per article, value: article id number
sentence_counts = defaultdict(lambda: [])
max_sentences = 0
total_sentences = 0
for article_id in self.sentences:
current_length = len(self.sentences[article_id])
sentence_counts[current_length].append(article_id)
max_sentences = max(max_sentences, current_length)
total_sentences += current_length
n_sentences_assigned_to_training = int((1 - self.fraction_test_set) * total_sentences)
nominal_sentences_per_training_shard = n_sentences_assigned_to_training // self.n_training_shards
nominal_sentences_per_test_shard = (total_sentences - n_sentences_assigned_to_training) // self.n_test_shards
consumed_article_set = set({})
unused_article_set = set(self.articles.keys())
# Make first pass and add one article worth of lines per file
for file in self.output_training_files:
current_article_id = sentence_counts[max_sentences][-1]
sentence_counts[max_sentences].pop(-1)
self.output_training_files[file].append(current_article_id)
consumed_article_set.add(current_article_id)
unused_article_set.remove(current_article_id)
# Maintain the max sentence count
while len(sentence_counts[max_sentences]) == 0 and max_sentences > 0:
max_sentences -= 1
if len(self.sentences[current_article_id]) > nominal_sentences_per_training_shard:
nominal_sentences_per_training_shard = len(self.sentences[current_article_id])
print('Warning: A single article contains more than the nominal number of sentences per training shard.')
for file in self.output_test_files:
current_article_id = sentence_counts[max_sentences][-1]
sentence_counts[max_sentences].pop(-1)
self.output_test_files[file].append(current_article_id)
consumed_article_set.add(current_article_id)
unused_article_set.remove(current_article_id)
# Maintain the max sentence count
while len(sentence_counts[max_sentences]) == 0 and max_sentences > 0:
max_sentences -= 1
if len(self.sentences[current_article_id]) > nominal_sentences_per_test_shard:
nominal_sentences_per_test_shard = len(self.sentences[current_article_id])
print('Warning: A single article contains more than the nominal number of sentences per test shard.')
training_counts = []
test_counts = []
for shard in self.output_training_files:
training_counts.append(self.get_sentences_per_shard(self.output_training_files[shard]))
for shard in self.output_test_files:
test_counts.append(self.get_sentences_per_shard(self.output_test_files[shard]))
training_median = statistics.median(training_counts)
test_median = statistics.median(test_counts)
# Make subsequent passes over files to find articles to add without going over limit
history_remaining = []
n_history_remaining = 4
while len(consumed_article_set) < len(self.articles):
for fidx, file in enumerate(self.output_training_files):
nominal_next_article_size = min(nominal_sentences_per_training_shard - training_counts[fidx], max_sentences)
# Maintain the max sentence count
while len(sentence_counts[max_sentences]) == 0 and max_sentences > 0:
max_sentences -= 1
while len(sentence_counts[nominal_next_article_size]) == 0 and nominal_next_article_size > 0:
nominal_next_article_size -= 1
if nominal_next_article_size not in sentence_counts or nominal_next_article_size is 0 or training_counts[fidx] > training_median:
continue # skip adding to this file, will come back later if no file can accept unused articles
current_article_id = sentence_counts[nominal_next_article_size][-1]
sentence_counts[nominal_next_article_size].pop(-1)
self.output_training_files[file].append(current_article_id)
consumed_article_set.add(current_article_id)
unused_article_set.remove(current_article_id)
for fidx, file in enumerate(self.output_test_files):
nominal_next_article_size = min(nominal_sentences_per_test_shard - test_counts[fidx], max_sentences)
# Maintain the max sentence count
while len(sentence_counts[max_sentences]) == 0 and max_sentences > 0:
max_sentences -= 1
while len(sentence_counts[nominal_next_article_size]) == 0 and nominal_next_article_size > 0:
nominal_next_article_size -= 1
if nominal_next_article_size not in sentence_counts or nominal_next_article_size is 0 or test_counts[fidx] > test_median:
continue # skip adding to this file, will come back later if no file can accept unused articles
current_article_id = sentence_counts[nominal_next_article_size][-1]
sentence_counts[nominal_next_article_size].pop(-1)
self.output_test_files[file].append(current_article_id)
consumed_article_set.add(current_article_id)
unused_article_set.remove(current_article_id)
# If unable to place articles a few times, bump up nominal sizes by fraction until articles get placed
if len(history_remaining) == n_history_remaining:
history_remaining.pop(0)
history_remaining.append(len(unused_article_set))
history_same = True
for i in range(1, len(history_remaining)):
history_same = history_same and (history_remaining[i-1] == history_remaining[i])
if history_same:
nominal_sentences_per_training_shard += 1
# nominal_sentences_per_test_shard += 1
training_counts = []
test_counts = []
for shard in self.output_training_files:
training_counts.append(self.get_sentences_per_shard(self.output_training_files[shard]))
for shard in self.output_test_files:
test_counts.append(self.get_sentences_per_shard(self.output_test_files[shard]))
training_median = statistics.median(training_counts)
test_median = statistics.median(test_counts)
print('Distributing data over shards:', len(unused_article_set), 'articles remaining.')
if len(unused_article_set) != 0:
print('Warning: Some articles did not make it into output files.')
for shard in self.output_training_files:
print('Training shard:', self.get_sentences_per_shard(self.output_training_files[shard]))
for shard in self.output_test_files:
print('Test shard:', self.get_sentences_per_shard(self.output_test_files[shard]))
print('End: Distribute Articles Over Shards')
def write_shards_to_disk(self):
print('Start: Write Shards to Disk')
for shard in self.output_training_files:
self.write_single_shard(shard, self.output_training_files[shard])
for shard in self.output_test_files:
self.write_single_shard(shard, self.output_test_files[shard])
print('End: Write Shards to Disk')
def write_single_shard(self, shard_name, shard):
with open(shard_name, mode='w', newline='\n') as f:
for article_id in shard:
for line in self.sentences[article_id]:
f.write(line + '\n')
f.write('\n') # Line break between articles
import nltk
nltk.download('punkt')
class NLTKSegmenter:
def __init(self):
pass
def segment_string(self, article):
return nltk.tokenize.sent_tokenize(article)
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bz2
import os
import urllib.request
import subprocess
import sys
class WikiDownloader:
def __init__(self, language, save_path):
self.save_path = save_path + '/wikicorpus_' + language
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
self.language = language
self.download_urls = {
'en' : 'https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2',
'zh' : 'https://dumps.wikimedia.org/zhwiki/latest/zhwiki-latest-pages-articles.xml.bz2'
}
self.output_files = {
'en' : 'wikicorpus_en.xml.bz2',
'zh' : 'wikicorpus_zh.xml.bz2'
}
def download(self):
if self.language in self.download_urls:
url = self.download_urls[self.language]
filename = self.output_files[self.language]
print('Downloading:', url)
if os.path.isfile(self.save_path + '/' + filename):
print('** Download file already exists, skipping download')
else:
response = urllib.request.urlopen(url)
with open(self.save_path + '/' + filename, "wb") as handle:
handle.write(response.read())
# Always unzipping since this is relatively fast and will overwrite
print('Unzipping:', self.output_files[self.language])
subprocess.run('bzip2 -dk ' + self.save_path + '/' + filename, shell=True, check=True)
else:
assert False, 'WikiDownloader not implemented for this language yet.'
\ No newline at end of file
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
class WikicorpusTextFormatting:
def __init__(self, wiki_path, output_filename, recursive = False):
self.wiki_path = wiki_path
self.recursive = recursive
self.output_filename = output_filename
# This puts one article per line
def merge(self):
with open(self.output_filename, mode='w', newline='\n') as ofile:
for dirname in glob.glob(self.wiki_path + '/*/', recursive=False):
for filename in glob.glob(dirname + 'wiki_*', recursive=self.recursive):
print(filename)
article_lines = []
article_open = False
with open(filename, mode='r', newline='\n') as file:
for line in file:
if '<doc id=' in line:
article_open = True
elif '</doc>' in line:
article_open = False
for oline in article_lines[1:]:
if oline != '\n':
ofile.write(oline.rstrip() + " ")
ofile.write("\n\n")
article_lines = []
else:
if article_open:
article_lines.append(line)
\ No newline at end of file
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import BookscorpusTextFormatting
import Downloader
import TextSharding
import WikicorpusTextFormatting
import argparse
import itertools
import multiprocessing
import os
import pprint
import subprocess
def main(args):
working_dir = os.environ['BERT_PREP_WORKING_DIR']
print('Working Directory:', working_dir)
print('Action:', args.action)
print('Dataset Name:', args.dataset)
if args.input_files:
args.input_files = args.input_files.split(',')
hdf5_tfrecord_folder_prefix = "_lower_case_" + str(args.do_lower_case) + "_seq_len_" + str(args.max_seq_length) \
+ "_max_pred_" + str(args.max_predictions_per_seq) + "_masked_lm_prob_" + str(args.masked_lm_prob) \
+ "_random_seed_" + str(args.random_seed) + "_dupe_factor_" + str(args.dupe_factor)
directory_structure = {
'download' : working_dir + '/download', # Downloaded and decompressed
'extracted' : working_dir +'/extracted', # Extracted from whatever the initial format is (e.g., wikiextractor)
'formatted' : working_dir + '/formatted_one_article_per_line', # This is the level where all sources should look the same
'sharded' : working_dir + '/sharded_' + "training_shards_" + str(args.n_training_shards) + "_test_shards_" + str(args.n_test_shards) + "_fraction_" + str(args.fraction_test_set),
'tfrecord' : working_dir + '/tfrecord'+ hdf5_tfrecord_folder_prefix,
'hdf5': working_dir + '/hdf5' + hdf5_tfrecord_folder_prefix
}
print('\nDirectory Structure:')
pp = pprint.PrettyPrinter(indent=2)
pp.pprint(directory_structure)
print('')
if args.action == 'download':
if not os.path.exists(directory_structure['download']):
os.makedirs(directory_structure['download'])
downloader = Downloader.Downloader(args.dataset, directory_structure['download'])
downloader.download()
elif args.action == 'text_formatting':
assert args.dataset != 'google_pretrained_weights' and args.dataset != 'nvidia_pretrained_weights' and args.dataset != 'squad' and args.dataset != 'mrpc', 'Cannot perform text_formatting on pretrained weights'
if not os.path.exists(directory_structure['extracted']):
os.makedirs(directory_structure['extracted'])
if not os.path.exists(directory_structure['formatted']):
os.makedirs(directory_structure['formatted'])
if args.dataset == 'bookscorpus':
books_path = directory_structure['download'] + '/bookscorpus'
#books_path = directory_structure['download']
output_filename = directory_structure['formatted'] + '/bookscorpus_one_book_per_line.txt'
books_formatter = BookscorpusTextFormatting.BookscorpusTextFormatting(books_path, output_filename, recursive=True)
books_formatter.merge()
elif args.dataset == 'wikicorpus_en':
if args.skip_wikiextractor == 0:
path_to_wikiextractor_in_container = '/workspace/wikiextractor/WikiExtractor.py'
wikiextractor_command = path_to_wikiextractor_in_container + ' ' + directory_structure['download'] + '/' + args.dataset + '/wikicorpus_en.xml ' + '-b 100M --processes ' + str(args.n_processes) + ' -o ' + directory_structure['extracted'] + '/' + args.dataset
print('WikiExtractor Command:', wikiextractor_command)
wikiextractor_process = subprocess.run(wikiextractor_command, shell=True, check=True)
#wikiextractor_process.communicate()
wiki_path = directory_structure['extracted'] + '/wikicorpus_en'
output_filename = directory_structure['formatted'] + '/wikicorpus_en_one_article_per_line.txt'
wiki_formatter = WikicorpusTextFormatting.WikicorpusTextFormatting(wiki_path, output_filename, recursive=True)
wiki_formatter.merge()
elif args.dataset == 'wikicorpus_zh':
assert False, 'wikicorpus_zh not fully supported at this time. The simplified/tradition Chinese data needs to be translated and properly segmented still, and should work once this step is added.'
if args.skip_wikiextractor == 0:
path_to_wikiextractor_in_container = '/workspace/wikiextractor/WikiExtractor.py'
wikiextractor_command = path_to_wikiextractor_in_container + ' ' + directory_structure['download'] + '/' + args.dataset + '/wikicorpus_zh.xml ' + '-b 100M --processes ' + str(args.n_processes) + ' -o ' + directory_structure['extracted'] + '/' + args.dataset
print('WikiExtractor Command:', wikiextractor_command)
wikiextractor_process = subprocess.run(wikiextractor_command, shell=True, check=True)
#wikiextractor_process.communicate()
wiki_path = directory_structure['extracted'] + '/wikicorpus_zh'
output_filename = directory_structure['formatted'] + '/wikicorpus_zh_one_article_per_line.txt'
wiki_formatter = WikicorpusTextFormatting.WikicorpusTextFormatting(wiki_path, output_filename, recursive=True)
wiki_formatter.merge()
assert os.stat(output_filename).st_size > 0, 'File glob did not pick up extracted wiki files from WikiExtractor.'
elif args.action == 'sharding':
# Note: books+wiki requires user to provide list of input_files (comma-separated with no spaces)
if args.dataset == 'bookscorpus' or 'wikicorpus' in args.dataset or 'books_wiki' in args.dataset:
if args.input_files is None:
if args.dataset == 'bookscorpus':
args.input_files = [directory_structure['formatted'] + '/bookscorpus_one_book_per_line.txt']
elif args.dataset == 'wikicorpus_en':
args.input_files = [directory_structure['formatted'] + '/wikicorpus_en_one_article_per_line.txt']
elif args.dataset == 'wikicorpus_zh':
args.input_files = [directory_structure['formatted'] + '/wikicorpus_zh_one_article_per_line.txt']
elif args.dataset == 'books_wiki_en_corpus':
args.input_files = [directory_structure['formatted'] + '/bookscorpus_one_book_per_line.txt', directory_structure['formatted'] + '/wikicorpus_en_one_article_per_line.txt']
output_file_prefix = directory_structure['sharded'] + '/' + args.dataset + '/' + args.dataset
if not os.path.exists(directory_structure['sharded']):
os.makedirs(directory_structure['sharded'])
if not os.path.exists(directory_structure['sharded'] + '/' + args.dataset):
os.makedirs(directory_structure['sharded'] + '/' + args.dataset)
# Segmentation is here because all datasets look the same in one article/book/whatever per line format, and
# it seemed unnecessarily complicated to add an additional preprocessing step to call just for this.
# Different languages (e.g., Chinese simplified/traditional) may require translation and
# other packages to be called from here -- just add a conditional branch for those extra steps
segmenter = TextSharding.NLTKSegmenter()
sharding = TextSharding.Sharding(args.input_files, output_file_prefix, args.n_training_shards, args.n_test_shards, args.fraction_test_set)
sharding.load_articles()
sharding.segment_articles_into_sentences(segmenter)
sharding.distribute_articles_over_shards()
sharding.write_shards_to_disk()
else:
assert False, 'Unsupported dataset for sharding'
elif args.action == 'create_tfrecord_files':
assert False, 'TFrecord creation not supported in this PyTorch model example release.' \
''
if not os.path.exists(directory_structure['tfrecord'] + "/" + args.dataset):
os.makedirs(directory_structure['tfrecord'] + "/" + args.dataset)
def create_record_worker(filename_prefix, shard_id, output_format='tfrecord'):
bert_preprocessing_command = 'python /workspace/bert/create_pretraining_data.py'
bert_preprocessing_command += ' --input_file=' + directory_structure['sharded'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.txt'
bert_preprocessing_command += ' --output_file=' + directory_structure['tfrecord'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.' + output_format
bert_preprocessing_command += ' --vocab_file=' + args.vocab_file
bert_preprocessing_command += ' --do_lower_case' if args.do_lower_case else ''
bert_preprocessing_command += ' --max_seq_length=' + str(args.max_seq_length)
bert_preprocessing_command += ' --max_predictions_per_seq=' + str(args.max_predictions_per_seq)
bert_preprocessing_command += ' --masked_lm_prob=' + str(args.masked_lm_prob)
bert_preprocessing_command += ' --random_seed=' + str(args.random_seed)
bert_preprocessing_command += ' --dupe_factor=' + str(args.dupe_factor)
bert_preprocessing_process = subprocess.Popen(bert_preprocessing_command, shell=True)
last_process = bert_preprocessing_process
# This could be better optimized (fine if all take equal time)
if shard_id % args.n_processes == 0 and shard_id > 0:
bert_preprocessing_process.wait()
return last_process
output_file_prefix = args.dataset
for i in range(args.n_training_shards):
last_process =create_record_worker(output_file_prefix + '_training', i)
last_process.wait()
for i in range(args.n_test_shards):
last_process = create_record_worker(output_file_prefix + '_test', i)
last_process.wait()
elif args.action == 'create_hdf5_files':
last_process = None
if not os.path.exists(directory_structure['hdf5'] + "/" + args.dataset):
os.makedirs(directory_structure['hdf5'] + "/" + args.dataset)
def create_record_worker(filename_prefix, shard_id, output_format='hdf5'):
bert_preprocessing_command = 'python /workspace/bert/create_pretraining_data.py'
bert_preprocessing_command += ' --input_file=' + directory_structure['sharded'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.txt'
bert_preprocessing_command += ' --output_file=' + directory_structure['hdf5'] + '/' + args.dataset + '/' + filename_prefix + '_' + str(shard_id) + '.' + output_format
bert_preprocessing_command += ' --vocab_file=' + args.vocab_file
bert_preprocessing_command += ' --do_lower_case' if args.do_lower_case else ''
bert_preprocessing_command += ' --max_seq_length=' + str(args.max_seq_length)
bert_preprocessing_command += ' --max_predictions_per_seq=' + str(args.max_predictions_per_seq)
bert_preprocessing_command += ' --masked_lm_prob=' + str(args.masked_lm_prob)
bert_preprocessing_command += ' --random_seed=' + str(args.random_seed)
bert_preprocessing_command += ' --dupe_factor=' + str(args.dupe_factor)
bert_preprocessing_process = subprocess.Popen(bert_preprocessing_command, shell=True)
last_process = bert_preprocessing_process
# This could be better optimized (fine if all take equal time)
if shard_id % args.n_processes == 0 and shard_id > 0:
bert_preprocessing_process.wait()
return last_process
output_file_prefix = args.dataset
for i in range(args.n_training_shards):
last_process = create_record_worker(output_file_prefix + '_training', i)
last_process.wait()
for i in range(args.n_test_shards):
last_process = create_record_worker(output_file_prefix + '_test', i)
last_process.wait()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Preprocessing Application for Everything BERT-related'
)
parser.add_argument(
'--action',
type=str,
help='Specify the action you want the app to take. e.g., generate vocab, segment, create tfrecords',
choices={
'download', # Download and verify mdf5/sha sums
'text_formatting', # Convert into a file that contains one article/book per line
'sharding', # Convert previous formatted text into shards containing one sentence per line
'create_tfrecord_files', # Turn each shard into a TFrecord with masking and next sentence prediction info
'create_hdf5_files' # Turn each shard into a HDF5 file with masking and next sentence prediction info
}
)
parser.add_argument(
'--dataset',
type=str,
help='Specify the dataset to perform --action on',
choices={
'bookscorpus',
'wikicorpus_en',
'wikicorpus_zh',
'books_wiki_en_corpus',
'google_pretrained_weights',
'nvidia_pretrained_weights',
'mrpc',
'sst-2',
'squad',
'all'
}
)
parser.add_argument(
'--input_files',
type=str,
help='Specify the input files in a comma-separated list (no spaces)'
)
parser.add_argument(
'--n_training_shards',
type=int,
help='Specify the number of training shards to generate',
default=256
)
parser.add_argument(
'--n_test_shards',
type=int,
help='Specify the number of test shards to generate',
default=256
)
parser.add_argument(
'--fraction_test_set',
type=float,
help='Specify the fraction (0..1) of the data to withhold for the test data split (based on number of sequences)',
default=0.1
)
parser.add_argument(
'--segmentation_method',
type=str,
help='Specify your choice of sentence segmentation',
choices={
'nltk'
},
default='nltk'
)
parser.add_argument(
'--n_processes',
type=int,
help='Specify the max number of processes to allow at one time',
default=4
)
parser.add_argument(
'--random_seed',
type=int,
help='Specify the base seed to use for any random number generation',
default=12345
)
parser.add_argument(
'--dupe_factor',
type=int,
help='Specify the duplication factor',
default=5
)
parser.add_argument(
'--masked_lm_prob',
type=float,
help='Specify the probability for masked lm',
default=0.15
)
parser.add_argument(
'--max_seq_length',
type=int,
help='Specify the maximum sequence length',
default=512
)
parser.add_argument(
'--max_predictions_per_seq',
type=int,
help='Specify the maximum number of masked words per sequence',
default=20
)
parser.add_argument(
'--do_lower_case',
type=int,
help='Specify whether it is cased (0) or uncased (1) (any number greater than 0 will be treated as uncased)',
default=1
)
parser.add_argument(
'--vocab_file',
type=str,
help='Specify absolute path to vocab file to use)'
)
parser.add_argument(
'--skip_wikiextractor',
type=int,
help='Specify whether to skip wikiextractor step 0=False, 1=True',
default=0
)
parser.add_argument(
'--interactive_json_config_generator',
type=str,
help='Specify the action you want the app to take. e.g., generate vocab, segment, create tfrecords'
)
args = parser.parse_args()
main(args)
#!/bin/bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
to_download=${1:-"wiki_only"}
#Download
if [ "$to_download" = "wiki_books" ] ; then
python3 /workspace/bert/data/bertPrep.py --action download --dataset bookscorpus
fi
python3 /workspace/bert/data/bertPrep.py --action download --dataset wikicorpus_en
python3 /workspace/bert/data/bertPrep.py --action download --dataset google_pretrained_weights # Includes vocab
python3 /workspace/bert/data/bertPrep.py --action download --dataset squad
python3 /workspace/bert/data/bertPrep.py --action download --dataset mrpc
python3 /workspace/bert/data/bertPrep.py --action download --dataset sst-2
# Properly format the text files
if [ "$to_download" = "wiki_books" ] ; then
python3 /workspace/bert/data/bertPrep.py --action text_formatting --dataset bookscorpus
fi
python3 /workspace/bert/data/bertPrep.py --action text_formatting --dataset wikicorpus_en
if [ "$to_download" = "wiki_books" ] ; then
DATASET="books_wiki_en_corpus"
else
DATASET="wikicorpus_en"
# Shard the text files
fi
# Shard the text files
python3 /workspace/bert/data/bertPrep.py --action sharding --dataset $DATASET
# Create HDF5 files Phase 1
python3 /workspace/bert/data/bertPrep.py --action create_hdf5_files --dataset $DATASET --max_seq_length 128 \
--max_predictions_per_seq 20 --vocab_file $BERT_PREP_WORKING_DIR/download/google_pretrained_weights/uncased_L-24_H-1024_A-16/vocab.txt --do_lower_case 1
# Create HDF5 files Phase 2
python3 /workspace/bert/data/bertPrep.py --action create_hdf5_files --dataset $DATASET --max_seq_length 512 \
--max_predictions_per_seq 80 --vocab_file $BERT_PREP_WORKING_DIR/download/google_pretrained_weights/uncased_L-24_H-1024_A-16/vocab.txt --do_lower_case 1
#!/usr/bin/env bash
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
echo "Downloading dataset for squad..."
# Download SQuAD
v1="v1.1"
mkdir $v1
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -O $v1/train-v1.1.json
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -O $v1/dev-v1.1.json
wget https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/ -O $v1/evaluate-v1.1.py
EXP_TRAIN_v1='981b29407e0affa3b1b156f72073b945 -'
EXP_DEV_v1='3e85deb501d4e538b6bc56f786231552 -'
EXP_EVAL_v1='afb04912d18ff20696f7f88eed49bea9 -'
CALC_TRAIN_v1=`cat ${v1}/train-v1.1.json |md5sum`
CALC_DEV_v1=`cat ${v1}/dev-v1.1.json |md5sum`
CALC_EVAL_v1=`cat ${v1}/evaluate-v1.1.py |md5sum`
v2="v2.0"
mkdir $v2
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json -O $v2/train-v2.0.json
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O $v2/dev-v2.0.json
wget https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ -O $v2/evaluate-v2.0.py
EXP_TRAIN_v2='62108c273c268d70893182d5cf8df740 -'
EXP_DEV_v2='246adae8b7002f8679c027697b0b7cf8 -'
EXP_EVAL_v2='ff23213bed5516ea4a6d9edb6cd7d627 -'
CALC_TRAIN_v2=`cat ${v2}/train-v2.0.json |md5sum`
CALC_DEV_v2=`cat ${v2}/dev-v2.0.json |md5sum`
CALC_EVAL_v2=`cat ${v2}/evaluate-v2.0.py |md5sum`
echo "Squad data download done!"
echo "Verifying Dataset...."
if [ "$EXP_TRAIN_v1" != "$CALC_TRAIN_v1" ]; then
echo "train-v1.1.json is corrupted! md5sum doesn't match"
fi
if [ "$EXP_DEV_v1" != "$CALC_DEV_v1" ]; then
echo "dev-v1.1.json is corrupted! md5sum doesn't match"
fi
if [ "$EXP_EVAL_v1" != "$CALC_EVAL_v1" ]; then
echo "evaluate-v1.1.py is corrupted! md5sum doesn't match"
fi
if [ "$EXP_TRAIN_v2" != "$CALC_TRAIN_v2" ]; then
echo "train-v2.0.json is corrupted! md5sum doesn't match"
fi
if [ "$EXP_DEV_v2" != "$CALC_DEV_v2" ]; then
echo "dev-v2.0.json is corrupted! md5sum doesn't match"
fi
if [ "$EXP_EVAL_v2" != "$CALC_EVAL_v2" ]; then
echo "evaluate-v2.0.py is corrupted! md5sum doesn't match"
fi
echo "Complete!"
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment