Unverified Commit 203ca57a authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[example] add GPT

parent fd2c8d81
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import time
import ftfy
import numpy as np
from langdetect import detect
from tokenizer import Tokenizer
MIN_DOCUMENT_LENGTH = 128
def print_progress(prefix, start_time, num_docs, num_fixed_text,
                   num_non_english_docs, chars_non_english_docs,
                   num_small_docs, chars_small_docs):
    """Print a single pipe-separated progress/summary line for the filter loop."""
    fields = [
        prefix,
        'elapsed time: {:.2f}'.format(time.time() - start_time),
        'documents: {}'.format(num_docs),
        'fixed text: {}'.format(num_fixed_text),
        'non-english: {}'.format(num_non_english_docs),
        'non-english chars: {}'.format(chars_non_english_docs),
        'small docs: {}'.format(num_small_docs),
        'small docs chars: {}'.format(chars_small_docs),
    ]
    print(' | '.join(fields), flush=True)
def filter_corpus(filename, out_filename, print_interval=10000):
    """Filter a json-lines corpus: fix text with ftfy, drop non-English
    documents and documents shorter than MIN_DOCUMENT_LENGTH tokens, and
    write the survivors to `out_filename` (utf-8, one json object per line).
    """
    print(' > filtering {}'.format(filename))
    tokenizer = Tokenizer(cache_dir='./cache')
    # Running statistics for the progress reports.
    docs_seen = 0
    docs_written = 0
    small_docs = 0
    fixed_docs = 0
    non_english_docs = 0
    non_english_chars = 0
    small_chars = 0
    t0 = time.time()
    with open(out_filename, 'wb') as fout, open(filename, 'r') as fin:
        for line in fin:
            try:
                docs_seen += 1
                doc = json.loads(line)
                # Repair mojibake / encoding damage.
                fixed = ftfy.fix_text(doc['text'])
                if fixed != doc['text']:
                    fixed_docs += 1
                    doc['text'] = fixed
                # Keep English documents only.
                if detect(fixed) != 'en':
                    print('[non-english text]', doc)
                    non_english_docs += 1
                    non_english_chars += len(fixed)
                    continue
                # On average each token is 5 characters so 8 is an
                # upper bound: only tokenize when the doc could be short.
                if len(fixed) < (8 * MIN_DOCUMENT_LENGTH):
                    if len(tokenizer.tokenize_document(fixed)) < MIN_DOCUMENT_LENGTH:
                        print('[small document, skipping]:', doc)
                        small_docs += 1
                        small_chars += len(fixed)
                        continue
                fout.write(json.dumps(doc, ensure_ascii=False).encode('utf-8'))
                fout.write('\n'.encode('utf-8'))
                docs_written += 1
                if docs_seen % print_interval == 0:
                    print_progress('[PROGRESS]', t0, docs_seen, fixed_docs,
                                   non_english_docs, non_english_chars,
                                   small_docs, small_chars)
            except Exception as e:
                # Best-effort: a malformed line is reported and skipped.
                print(' skipping ', line, e)
    print_progress('[FINAL]', t0, docs_seen, fixed_docs, non_english_docs,
                   non_english_chars, small_docs, small_chars)
if __name__ == '__main__':
    # CLI: <input json-lines file> <output json-lines file>
    print('building gpt2 dataset ...')
    input_filename, output_filename = sys.argv[1], sys.argv[2]
    print('will be reading {}'.format(input_filename))
    print('and will write the results to {}'.format(output_filename))
    filter_corpus(input_filename, output_filename)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Filter and clean documents:
Capable to clean docs with less than 512 characters, less than
256 characters and contains javascript, fix text and dataset specific
cleaning like stories and realnews datasets.
Program arguments have the details.
"""
import argparse
import glob
import json
import multiprocessing
import os
import re
import time
from functools import partial
from pathlib import Path
import ftfy
from langdetect import detect
def process_doc(json_line, args):
    """Apply the first matching task from `args.tasks` to one json line.

    Returns (task_flags, text, document, to_filter) where `to_filter` is
    True when the document should be removed from the corpus.
    """
    document = json.loads(json_line)
    text = document['text']
    # One flag per supported task; exactly one may be set per call.
    output = {
        'remove_512': False,
        'remove_256_javascript': False,
        'remove_512_non_english': False,
        'ftfy_fix_text': False,
        'general_cleaning': False,
    }
    try:
        tasks = args.tasks
        # Remove all docs with less than 512 characters.
        if 'remove_512' in tasks and len(text) < 512:
            output['remove_512'] = True
            return output, text, document, True
        # Remove docs if less than 256 character length and contains Javascript.
        if 'remove_256_javascript' in tasks and len(text) < 256 \
                and 'javascript' in text.lower():
            output['remove_256_javascript'] = True
            return output, text, document, True
        # Remove docs < 512 chars and non-English.
        if 'remove_512_non_english' in tasks and len(text) < 512 \
                and detect(text) != 'en':
            output['remove_512_non_english'] = True
            return output, text, document, True
        # Fix the text using ftfy; keep the doc (to_filter=False).
        if 'ftfy_fix_text' in tasks:
            output['ftfy_fix_text'] = True
            return output, ftfy.fix_text(text), document, False
        # Collapse runs of spaces / stray newlines; keep the doc.
        # (The original kept several commented-out dataset-specific regexes
        # here for Gutenberg/realnews/stories; see version history.)
        if 'general_cleaning' in tasks:
            output['general_cleaning'] = True
            return output, re.sub(r" +|\b\n+ |\b\n+", " ", text), document, False
    except Exception as e:
        print('Error: *************************\n{}\ntext: {}'.format(e, \
            text), flush=True)
        return output, text, document, True
    # No task matched: keep the document unchanged.
    return output, text, document, False
def process_set(args, input_file, output_f_cleaned, output_f_filtered):
    """Run the cleaning tasks in ``args.tasks`` over one json-lines file.

    Kept (cleaned) documents are written to ``output_f_cleaned`` and removed
    (filtered) documents to ``output_f_filtered``, one json object per line.

    Improvement over the original: the three file handles are managed with
    ``with`` and the worker pool is closed/joined in ``finally``, so nothing
    leaks if a worker raises mid-stream.
    """
    print(' > working on {} ...'.format(input_file), flush=True)
    num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
        = num_ftfy_fix_text = num_general_cleaning = 0
    start_time = time.time()
    # Setup multi-processing.
    num_workers = 40
    with open(output_f_cleaned, 'wb') as output_cleaned, \
            open(output_f_filtered, 'wb') as output_filtered, \
            open(input_file, 'r', encoding='utf-8') as fin:
        pool = multiprocessing.Pool(num_workers)
        try:
            process_doc_partial = partial(process_doc, args=args)
            # Chunksize 500 amortizes IPC overhead over many small lines.
            for output, text, document, to_filter in \
                    pool.imap(process_doc_partial, fin, 500):
                num_docs += 1
                num_remove_512 += 1 if output['remove_512'] else 0
                num_remove_java += 1 if output['remove_256_javascript'] else 0
                num_remove_512_non_english += \
                    1 if output['remove_512_non_english'] else 0
                num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
                num_general_cleaning += 1 if output['general_cleaning'] else 0
                # Re-serialize with the (possibly cleaned) text.
                document['text'] = text
                myjson = json.dumps(document, ensure_ascii=False)
                target = output_filtered if to_filter else output_cleaned
                target.write(myjson.encode('utf-8'))
                target.write('\n'.encode('utf-8'))
                if num_docs % args.log_interval == 0:
                    print(' processed {:9d} documents in {:.2f} seconds ...'.format(
                        num_docs, time.time() - start_time), flush=True)
        finally:
            pool.close()
            pool.join()
    # Print stats.
    print(' >> total docs: {} remove_512 {} remove_256_javascript {} '
          'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.format(
              num_docs, num_remove_512, num_remove_java,
              num_remove_512_non_english, num_ftfy_fix_text,
              num_general_cleaning), flush=True)
if __name__ == '__main__':
    print('parsing the arguments ...')
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-files', nargs='*', required=True, default=None,
                        help='Input json files that needs to be cleaned')
    parser.add_argument('--tasks', nargs='*', required=True, default=None,
                        help='Tasks to perform on the input files, '
                             'such as remove_512, remove_256_javascript, '
                             'remove_512_non_english, ftfy_fix_text, and '
                             'general_cleaning. 256 or 512 means the number'
                             ' of characters.')
    parser.add_argument('--output-path', type=str, default=None,
                        help='Directory where the output should go')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Log interval')
    args = parser.parse_args()
    print('cleanup dataset ...')
    for input_file in args.input_files:
        # Derive <name>_cleaned.<ext> / <name>_filtered.<ext> in output-path.
        source = Path(input_file)
        output_f_cleaned = os.path.join(
            args.output_path, source.stem + "_cleaned" + source.suffix)
        output_f_filtered = os.path.join(
            args.output_path, source.stem + "_filtered" + source.suffix)
        process_set(args, input_file, output_f_cleaned, output_f_filtered)
    print('done :-)', flush=True)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
import json
import multiprocessing
import os
import pickle
import sys
import time
from functools import partial
import numpy as np
from lsh import cache, minhash
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
    """Return the set of character n-grams of `text`.

    NOTE(review): the range ends at len(text) - char_ngram, so the n-gram
    ending on the last character is never produced; kept as-is to match the
    upstream LSH example this was adapted from.
    """
    return {text[i:i + char_ngram] for i in range(len(text) - char_ngram)}
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b, args):
    """Similarity of two shingle sets; `args.jaccard` selects the denominator
    ('min', 'max', or anything else for the standard union)."""
    if not set_a or not set_b:
        return 0.0
    common = len(set_a & set_b)
    if args.jaccard == 'min':
        return common / min(len(set_a), len(set_b))
    if args.jaccard == 'max':
        return common / max(len(set_a), len(set_b))
    return common / len(set_a | set_b)
def compute_fingerprint(line, key):
    """Parse one json line and minhash-fingerprint its text.

    Returns (url, text, fingerprint, ok); on any failure the first three are
    None and ok is False. Relies on the module-global `hasher`.
    """
    try:
        doc = json.loads(line)
        url = doc[key]
        text = doc['text']
        return url, text, hasher.fingerprint(text), True
    except Exception as e:
        print('Error:', e)
        return None, None, None, False
def url_pairs_to_remove(args, bucket_urls, url_doc):
    """Pick near-duplicate urls to remove from one LSH bucket.

    Repeatedly picks a random pivot url, marks every other url in the bucket
    whose shingle jaccard similarity with the pivot exceeds 0.5, and removes
    the compared urls from the bucket. Stops after `args.heuristic_iter`
    rounds (-1 means run until the bucket is exhausted).

    Returns (removal_records, num_deduped, num_comparisons).
    """
    removal_records = []
    num_deduped = 0
    num_comparisons = 0
    iteration = 0
    while len(bucket_urls) > 1:
        if args.heuristic_iter != -1 and iteration == args.heuristic_iter:
            break
        candidates = list(bucket_urls)
        flagged = []
        # Random pivot for this round.
        pivot_url = candidates[np.random.randint(0, len(candidates))]
        pivot_shingles = shingles(url_doc[pivot_url])
        for other_url in candidates:
            num_comparisons += 1
            if other_url == pivot_url:
                continue
            other_shingles = shingles(url_doc[other_url])
            try:
                similarity = jaccard(pivot_shingles, other_shingles, args)
            except Exception as e:
                print('Error:', e)
                similarity = 0.0
            if similarity > 0.5:
                flagged.append({other_url: similarity})
                num_deduped += 1
                bucket_urls.remove(other_url)
        # The pivot itself is done either way.
        bucket_urls.remove(pivot_url)
        if flagged:
            removal_records.append({pivot_url: flagged})
        iteration += 1
    return removal_records, num_deduped, num_comparisons
def write_remove_urls_list(remove_urls_list, f_out):
    """Serialize each removal record to `f_out` as one utf-8 json line."""
    for record in remove_urls_list:
        f_out.write(json.dumps(record, ensure_ascii=False).encode('utf-8'))
        f_out.write(b'\n')
def compute_jaccard(each_bin, num_bins, start_time_local):
    """Scan every bucket of one LSH bin and collect urls to remove.

    Runs inside a worker process; reads the module-globals `args` and
    `url_doc` inherited from the parent. Returns
    (removal_records, num_deduped, num_comparisons).
    """
    removal_records = []
    num_deduped = 0
    num_comparisons = 0
    buckets_done = 0
    # Only roughly one worker (pid % num_bins == 0) prints progress.
    report_progress = os.getpid() % num_bins == 0
    for bucket_id in each_bin:
        buckets_done += 1
        if report_progress and buckets_done % 100000 == 0:
            print("Counter {}, progress {:.2f} time {:.2f}".format(
                buckets_done, float(buckets_done) / float(len(each_bin)),
                time.time() - start_time_local), flush=True)
        if len(each_bin[bucket_id]) <= 1:
            continue
        sub_records, sub_deduped, sub_comparisons = \
            url_pairs_to_remove(args, each_bin[bucket_id].copy(), url_doc)
        num_deduped += sub_deduped
        num_comparisons += sub_comparisons
        removal_records.extend(sub_records)
    return removal_records, num_deduped, num_comparisons
def find_pair_urls_parallel(args, lshcache, url_doc):
    """Find near-duplicate url pairs across all LSH bins in parallel and
    stream the removal records to ``args.output``.

    One worker per bin; `args` and `url_doc` are not passed to the workers —
    compute_jaccard reads them as globals inherited from this process.
    """
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0
    # compute jaccards of buckets in bin in parallel (parallelism
    # limited to # of bins)
    num_bins = len(lshcache.bins)
    pool = multiprocessing.Pool(num_bins)
    compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
        start_time_local=start_time)
    # don't need to pass args and url_doc as they are already shared
    compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)
    print("multiprocessing init took {:.2f}".format(time.time() - start_time),\
        flush=True)
    # Accumulate per-bin results as they complete and write them out.
    for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
        deduped += deduped_local
        counter += counter_local
        write_remove_urls_list(remove_urls_list, f_out)
        print(' [write]> processed {} documents in {:.2f} '
              'seconds and deduped {} documents ...'.format(counter, time.time()\
              - start_time, deduped), flush=True)
    pool.close()
    pool.join()
    f_out.close()
    print(' Taken time for jaccard similarities {:.2f} seconds'.format(\
        time.time() - start_time), flush=True)
def find_pair_urls_sequential(args, lshcache, url_doc):
    """Single-process version of the duplicate-pair search: walk every bucket
    of every LSH bin and write removal records to ``args.output``."""
    start_time = time.time()
    deduped = 0
    counter = 0
    with open(args.output, 'wb') as f_out:
        for lsh_bin in lshcache.bins:
            for bucket_id in lsh_bin:
                if len(lsh_bin[bucket_id]) <= 1:
                    continue
                sub_records, sub_deduped, sub_comparisons = \
                    url_pairs_to_remove(args, lsh_bin[bucket_id].copy(), url_doc)
                deduped += sub_deduped
                counter += sub_comparisons
                write_remove_urls_list(sub_records, f_out)
                if counter % 10000 == 0:
                    print(' [write]> processed {} documents in {:.2f} '
                          'seconds and deduped {} documents ...'.format(
                              counter, time.time() - start_time, deduped),
                          flush=True)
    print(' [write]> processed {} documents in {:.2f} '
          'seconds and deduped {} documents ...'.format(
              counter, time.time() - start_time, deduped),
          flush=True)
if __name__ == '__main__':
    # Driver: build minhash fingerprints for the input corpora (or load
    # previously saved ones) and find near-duplicate documents via LSH.
    print('parsing the arguments ...')
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy')
    parser.add_argument('--inputs', nargs='*', default=None,
                        help='Pairwise list of the input files and keys, '
                             'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs='*', default=None,
                        help='Load fingerprints from a list of pickle files,'
                             ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                        help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name that consists of all ids'
                             ' with matching similarities')
    parser.add_argument('--jaccard', type=str, default='union',
                        choices=['union', 'min', 'max'],
                        help='Jaccard similarity computation')
    parser.add_argument('--heuristic-iter', type=int, default=1,
                        help='Number of iterations to run the heuristics'
                             ': use -1 for exact')
    parser.add_argument('--num-bands', type=int, default=10,
                        help='Number of bands to use in cache')
    parser.add_argument('--num-seeds', type=int, default=100,
                        help='Number of seeds to use for minhash. Note that'
                             ' this value should be divisible by num-bands')
    parser.add_argument('--jaccard-parallel', action='store_true',
                        help='Use this to process large number of documents.')
    args = parser.parse_args()
    print('finding possible duplicate content ...')
    # set seed and get an array of seeds of 100 integers
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1e6, size=args.num_seeds)
    # initialize minhash and lsh cache
    # NOTE: `hasher`, `lshcache`, `url_doc` and `args` are module-level names
    # here; compute_fingerprint/compute_jaccard read them as globals, and the
    # multiprocessing workers inherit them from this process.
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)
    url_doc = {}
    # load fingerprints from pickle file if needed
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(
                fp_file_name), flush=True)
            fp = open(fp_file_name, "rb")
            if count_fp == 0:
                # assign directory for the first pkl
                lshcache = pickle.load(fp)
                url_doc = pickle.load(fp)
            else:
                # append these to lshcache and url_doc
                local_lshcache = pickle.load(fp)
                local_url_doc = pickle.load(fp)
                for url in local_lshcache.fingerprints.keys():
                    url_doc[url] = local_url_doc[url]
                    lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
            fp.close()
    counter = 0
    start_time = time.time()
    # compute finger prints of the inputs if any
    # input file and the key to use as id
    if args.inputs is not None:
        print("Computing fingerprints", flush=True)
        assert len(args.inputs) % 2 == 0
        # --inputs alternates file names and id keys: f1 k1 f2 k2 ...
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                  flush=True)
            # compute fingerprints in parallel
            num_workers = 40
            pool = multiprocessing.Pool(num_workers)
            fin = open(input_file, 'r', encoding='utf-8')
            compute_fingerprint_partial = partial(compute_fingerprint, key=key)
            compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
                                                 fin, 512)
            # traverse all the texts and add fingerprints
            for url, text, fingerprint, flag in compute_fingerprint_iter:
                counter += 1
                if flag:
                    url_doc[url] = text
                    lshcache.add_fingerprint(fingerprint, url)
                if counter % 10000 == 0:
                    print(' [read]> processed {} documents in {:.2f} '
                          'seconds ...'.format(counter, time.time() -
                                               start_time), flush=True)
            fin.close()
            pool.close()
            pool.join()
    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
            args.save_fingerprints), flush=True)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)
    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        print("Compute jaccard similarity", flush=True)
        if args.jaccard_parallel:
            find_pair_urls_parallel(args, lshcache, url_doc)
        else:
            find_pair_urls_sequential(args, lshcache, url_doc)
    print('done :-)')
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import logging
import os
import sys
from io import open
import regex as re
try:
    from functools import lru_cache
except ImportError:
    # Just a dummy decorator to get the checks to run on python2
    # because honestly I don't want to support a byte-level unicode BPE
    # tokenizer on python 2 right now.
    def lru_cache():
        return lambda func: func

# Module-level logger for load/save diagnostics.
logger = logging.getLogger(__name__)

# Remote locations of the pretrained gpt2 vocabulary and BPE merges files.
PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
# Maximum sequence length supported by each pretrained model's positional
# embeddings; used to cap `max_len` in from_pretrained.
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'gpt2': 1024,
}
# File names used when saving/loading a tokenizer to/from a directory.
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
    """
    Returns a mapping from every utf-8 byte (0-255) to a printable unicode
    character. The reversible bpe codes work on unicode strings, so bytes
    that would otherwise be whitespace/control characters (which the bpe
    code barfs on) are shifted up to unused codepoints starting at 256,
    while already-printable bytes map to themselves. This keeps the vocab
    free of UNKs without needing thousands of raw unicode characters.
    """
    _chr = unichr if sys.version_info[0] == 2 else chr
    # Bytes whose own codepoint is already a printable character.
    printable = list(range(ord("!"), ord("~") + 1))
    printable += list(range(ord("¡"), ord("¬") + 1))
    printable += list(range(ord("®"), ord("ÿ") + 1))
    byte_values = printable[:]
    char_codes = printable[:]
    # Every remaining byte gets the next free codepoint above 255.
    offset = 0
    for b in range(2**8):
        if b not in byte_values:
            byte_values.append(b)
            char_codes.append(2**8 + offset)
            offset += 1
    return {b: _chr(c) for b, c in zip(byte_values, char_codes)}
def get_pairs(word):
    """Return the set of adjacent symbol bigrams in a word.

    Word is represented as a tuple of symbols (symbols being
    variable-length strings).
    """
    # Indexing word[0] first (rather than zipping) preserves the original
    # behavior of raising IndexError on an empty word.
    previous = word[0]
    bigrams = set()
    for symbol in word[1:]:
        bigrams.add((previous, symbol))
        previous = symbol
    return bigrams
class GPT2Tokenizer(object):
    """
    GPT-2 BPE tokenizer. Peculiarities:
    - Byte-level BPE: input text is mapped byte-by-byte through
      bytes_to_unicode() before BPE merges are applied, so any input can be
      encoded without unknown tokens.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        # NOTE(review): cache_dir is accepted but not used anywhere below —
        # confirm whether it should be forwarded to cached_path.
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            # Known shortcut name: resolve to the remote archive urls.
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
            special_tokens_file = None
        else:
            # Otherwise treat it as a local directory holding the files.
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
            if not os.path.exists(special_tokens_file):
                special_tokens_file = None
            else:
                logger.info("loading special tokens file {}".format(special_tokens_file))
        # redirect to the cache, if necessary
        try:
            from cached_path import cached_path
            resolved_vocab_file = cached_path(vocab_file)
            resolved_merges_file = cached_path(merges_file)
        except EnvironmentError:
            # Download/lookup failed: report and signal failure with None.
            logger.error("Model name '{}' was not found in model name list ({}). "
                         "We assumed '{}' was a path or url but couldn't find files {} and {} "
                         "at this path or url.".format(pretrained_model_name_or_path,
                                                       ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                                                       pretrained_model_name_or_path, vocab_file, merges_file))
            return None
        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
            logger.info("loading merges file {}".format(merges_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(vocab_file, resolved_vocab_file))
            logger.info("loading merges file {} from cache at {}".format(merges_file, resolved_merges_file))
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
            # than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        if special_tokens_file and 'special_tokens' not in kwargs:
            # One special token per line; drop the trailing empty entry.
            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
        else:
            special_tokens = kwargs.pop('special_tokens', [])
        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
        return tokenizer

    def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
        """Load vocab and merges from disk and build encode/decode tables."""
        self.max_len = max_len if max_len is not None else int(1e12)
        # token string -> id, and its inverse.
        self.encoder = json.load(open(vocab_file))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        # byte value -> printable unicode stand-in, and its inverse.
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # First line of the merges file is a version header; last is empty.
        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        # merge pair -> rank; lower rank merges are applied first.
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}  # memoizes bpe() results per token
        # Should haved added re.IGNORECASE so BPE merges can happen for
        # capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        self.special_tokens = {}
        self.special_tokens_decoder = {}
        self.set_special_tokens(special_tokens)

    def __len__(self):
        # Total vocabulary size including special tokens.
        return len(self.encoder) + len(self.special_tokens)

    def set_special_tokens(self, special_tokens):
        """ Add a list of additional tokens to the encoder.
        The additional tokens are indexed starting from the last index of the
        current vocabulary in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
        logger.info("Special tokens {}".format(self.special_tokens))

    def bpe(self, token):
        """Apply the BPE merge rules to one pre-tokenized token and return
        the resulting symbols joined by single spaces. Results are cached."""
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)
        if not pairs:
            # Single-character token: nothing to merge.
            return token
        while True:
            # Merge the lowest-ranked bigram currently present in the word.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except BaseException:
                    # `first` no longer occurs: copy the tail and stop.
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def tokenize(self, text):
        """ Tokenize a string. """
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            if sys.version_info[0] == 2:
                token = ''.join(self.byte_encoder[ord(b)] for b in token)
            else:
                # Map raw utf-8 bytes to their printable stand-ins first.
                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def convert_tokens_to_ids(self, tokens):
        """ Converts a sequence of tokens into ids using the vocab. """
        ids = []
        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
            # Single token given as a bare string: return its id directly.
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.encoder.get(tokens, 0)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                # Unknown tokens fall back to id 0.
                ids.append(self.encoder.get(token, 0))
        if len(ids) > self.max_len:
            logger.warning("Token indices sequence length is longer than the specified maximum "
                           " sequence length for this OpenAI GPT model ({} > {}). Running this"
                           " sequence through the model will result in indexing errors".format(len(ids), self.max_len))
        return ids

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Converts a sequence of ids in BPE tokens using the vocab."""
        tokens = []
        for i in ids:
            if i in self.special_tokens_decoder:
                if not skip_special_tokens:
                    tokens.append(self.special_tokens_decoder[i])
            else:
                tokens.append(self.decoder[i])
        return tokens

    def encode(self, text):
        """Text -> list of vocab ids."""
        return self.convert_tokens_to_ids(self.tokenize(text))

    def decode(self, tokens):
        """List of vocab ids -> text (inverse of encode, minus special tokens)."""
        text = ''.join([self.decoder[token] for token in tokens])
        # Undo the byte-to-unicode mapping and decode the raw utf-8 bytes.
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(vocab_path):
            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
            return
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        merge_file = os.path.join(vocab_path, MERGES_NAME)
        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
        with open(vocab_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))
        index = 0
        # Write merges in rank order; warn if ranks are not consecutive.
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write(u'#version: 0.2\n')
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1
        # Special tokens are saved one per line, in id order.
        index = len(self.encoder)
        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return vocab_file, merge_file, special_tokens_file
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
import time
if __name__ == '__main__':
    # Group duplicate urls into connected components: every line of the input
    # is a json record {main_url: [{other_url: similarity}, ...]} produced by
    # the dedup step; urls whose similarity meets the threshold end up in the
    # same group. argv: <input json-lines> <output json-lines> [threshold]
    print('grouping duplicate urls ...')
    input = sys.argv[1]
    output = sys.argv[2]
    if len(sys.argv) > 3:
        jaccard_similarity_threshold = float(sys.argv[3])
    else:
        jaccard_similarity_threshold = 0.7
    # url -> index of its group in index_to_urls; merged-away groups become None.
    url_to_index = {}
    index_to_urls = []
    counter = 0
    start_time = time.time()
    with open(input, 'r') as f:
        for line in f:
            counter += 1
            myjson = json.loads(line)
            urls = []
            # Collect the main url plus every similar-enough url on this line.
            for main_url in myjson.keys():
                urls.append(main_url)
                for value in myjson[main_url]:
                    for other_url, js in value.items():
                        if js >= jaccard_similarity_threshold:
                            urls.append(other_url)
            current_index = -1
            other_indices = set()
            # Find an existing group for any of these urls; remember any
            # additional groups that must be merged into it.
            for url in urls:
                if url in url_to_index:
                    if current_index == -1:
                        current_index = url_to_index[url]
                    elif current_index != url_to_index[url]:
                        other_indices.add(url_to_index[url])
            # No existing group found: start a new one.
            if current_index == -1:
                current_index = len(index_to_urls)
                index_to_urls.append(set())
            # Put every url from this line into the chosen group.
            for url in urls:
                url_to_index[url] = current_index
                index_to_urls[current_index].add(url)
            # Merge the other touched groups into the chosen one, then
            # retire them by setting their slot to None.
            for index in other_indices:
                for url in index_to_urls[index]:
                    index_to_urls[current_index].add(url)
                    url_to_index[url] = current_index
                index_to_urls[index] = None
            if counter % 100000 == 0:
                print(' > processed {} lines in {} seconds ...'.format(
                    counter, time.time() - start_time))
    # Stats: each surviving group keeps one url; the rest are removable.
    total_remove = 0
    total_remain = 0
    for urls in index_to_urls:
        if urls is not None:
            if len(urls) > 1:
                total_remove += (len(urls) - 1)
            total_remain += 1
    print('out of {} urls, only {} are unique and {} should be removed'.format(
        total_remove + total_remain, total_remain, total_remove))
    # Write each multi-url group as one json line: {group_index: [urls]}.
    with open(output, 'wb') as f:
        for i, urls in enumerate(index_to_urls):
            if urls is not None:
                if len(urls) > 1:
                    myjson = json.dumps({str(i): list(urls)},
                                        ensure_ascii=False)
                    f.write(myjson.encode('utf-8'))
                    f.write('\n'.encode('utf-8'))
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
import time
# Script entry point: drop documents whose URL was marked as a duplicate by
# the companion URL-grouping script.
# Usage: <prog> <url_file> <data_jsonl> <output_jsonl>
if __name__ == '__main__':
    url_filename = sys.argv[1]      # jsonl: one {group_id: [url, ...]} object per line
    data_filename = sys.argv[2]     # jsonl corpus with 'url' and 'text' fields
    output_filename = sys.argv[3]   # deduplicated jsonl output

    # Build the removal set. For each duplicate group, the first URL
    # (index 0) is kept and every other URL is scheduled for removal.
    urls = set()
    with open(url_filename, 'r') as f:
        for line in f:
            myjson = json.loads(line)
            for key in myjson:
                this_urls = myjson[key]
                for i in range(1, len(this_urls)):
                    urls.add(this_urls[i])
    print('will be removing {} urls'.format(len(urls)), flush=True)

    written_docs = 0
    removed_docs = 0
    removed_chars = 0
    start_time = time.time()
    # Stream the corpus, copying through every document whose URL is not in
    # the removal set. Output is written as UTF-8 bytes, one json per line.
    with open(output_filename, 'wb') as fout:
        with open(data_filename, 'r') as fin:
            for line in fin:
                try:
                    myjson = json.loads(line)
                    url = myjson['url']
                    if url in urls:
                        print('removing', myjson)
                        removed_docs += 1
                        removed_chars += len(myjson['text'])
                        continue
                    myjson = json.dumps(myjson, ensure_ascii=False)
                    fout.write(myjson.encode('utf-8'))
                    fout.write('\n'.encode('utf-8'))
                    written_docs += 1
                    if written_docs % 10000 == 0:
                        print(' [PROCESSED] time (s): {:.2f} | written: {} '
                              '| removed: {} (char: {})'.format(time.time() - start_time, written_docs,
                                                                removed_docs, removed_chars))
                except Exception as e:
                    # Best-effort: skip malformed lines rather than abort the run.
                    print('[SKIPPING]', line, e)
    print(' [PROCESSED] time (s): {:.2f} | written: {} '
          '| removed: {} (char: {})'.format(time.time() - start_time, written_docs, removed_docs, removed_chars))
    print('done :-)')
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append('..')
from gpt2_tokenization import GPT2Tokenizer
class Tokenizer:
    """Thin wrapper around the pretrained GPT-2 BPE tokenizer.

    Appends an end-of-document token to every encoded document so that
    documents can be concatenated into a single token stream.
    """

    def __init__(self, cache_dir=None):
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
        # Effectively disable the tokenizer's max-length truncation/warning.
        self.tokenizer.max_len = int(1e12)
        self.eod_token = self.tokenizer.encoder['<|endoftext|>']
        # Downstream storage packs token ids into uint16.
        assert self.eod_token < 65535, 'vocab size will not fit in uint16'
        print('> GPT2 tokenizer with {} vocab size and eod token {} ...'.format(len(self.tokenizer.encoder),
                                                                                self.eod_token))

    def tokenize_document(self, document):
        """Encode *document* and terminate it with the EOD token."""
        token_ids = self.tokenizer.encode(document)
        token_ids.append(self.eod_token)
        return token_ids
# Code taken in large part from https://github.com/jcpeterson/openwebtext
from __future__ import print_function
import argparse
import io
import json
import multiprocessing as mpl
import os
import os.path as op
import sqlite3
import tarfile
import time
import warnings
from glob import glob
from hashlib import sha256
import tldextract
from scrapers import bs4_scraper, newspaper_scraper, raw_scraper
# for backward compatibility
from six.moves.urllib.request import urlopen
from tqdm import tqdm
from utils import chunks, extract_month, linecount, mkdir
# ---------------------------------------------------------------------------
# Command-line interface for the scraper driver. `args` is consumed by
# download() defaults and by the __main__ block below.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("url_file", type=str)
parser.add_argument(
    "--save_uncompressed",
    action="store_true",
    default=False,
    help="whether to save the raw txt files to disk",
)
parser.add_argument(
    "--output",
    type=str,
    default='raw.json',
    help="where to save the output json",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="scraped",
    help="which folder in the working directory to use for output",
)
parser.add_argument(
    "--n_procs",
    type=int,
    default=10,
    help="how many processes (cores) to use for parallel scraping",
)
parser.add_argument(
    "--timeout",
    type=int,
    default=-1,
    help="maximum scrape time for a single URL; -1 means no limit",
)
parser.add_argument(
    "--max_urls",
    type=int,
    default=-1,
    help="maximum # of URLs to scrape; mostly for debugging",
)
parser.add_argument(
    "--chunk_size",
    type=int,
    default=100,
    help="how many URLs to scrape before saving to archive",
)
parser.add_argument(
    "--scraper",
    type=str,
    default="newspaper",
    choices=["raw", "bs4", "newspaper"],
    help="which text/content scraper to use; raw is html",
)
parser.add_argument(
    "--compress",
    action="store_true",
    default=False,
    help="whether to output scraped content as compressed archives",
)
parser.add_argument(
    "--compress_fmt",
    type=str,
    default="xz",
    choices=["xz", "bz2", "gz"],
    help="which archive format to use",
)
parser.add_argument(
    "--scraper_memoize",
    action="store_true",
    default=False,
    help="whether to use cache for newspaper",
)
parser.add_argument(
    "--show_warnings",
    action="store_true",
    default=False,
    help="whether to show warnings in general during scraping",
)
parser.add_argument(
    "--sqlite_meta",
    action="store_true",
    default=True,
    help="whether to use sqlite for storing meta. if false, json will be used instead",
)
# BUGFIX: with action="store_true" and default=True, --sqlite_meta could never
# be turned off despite its help text. Provide an explicit opt-out flag that
# writes to the same destination; existing invocations are unaffected.
parser.add_argument(
    "--no_sqlite_meta",
    dest="sqlite_meta",
    action="store_false",
    help="store per-document metadata as json files instead of sqlite",
)

args = parser.parse_args()

if not args.show_warnings:
    # avoid lots of datetime warnings
    warnings.filterwarnings("ignore")
def load_urls(fh, max_urls=-1):
    """Yield (index, line) pairs from *fh*, optionally capped at *max_urls*.

    With the default cap of -1 a lazy enumerate iterator is returned;
    otherwise a materialized list truncated to *max_urls* entries.
    """
    entries = enumerate(fh)
    if max_urls == -1:
        return entries
    return list(entries)[:max_urls]
def vet_link(link):
    """Probe *link* and report whether it looks like a fetchable HTML page.

    Returns (is_good_link, link_type), where link_type is the response
    Content-Type header (empty string when the request failed entirely).
    """
    # check if server responds with non-200 status code or link points to a
    # non-html file
    link_type, link_status = "", -1
    try:
        info = urlopen(link)
        link_type = info.headers["Content-Type"]
        link_status = info.status
    except Exception:
        # BUGFIX: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit during long crawls. Any URL/network error simply
        # leaves the failure defaults in place.
        pass

    # we want "text/html" only!
    is_good_link = False
    if "text/html" in link_type and link_status == 200:
        is_good_link = True

    return is_good_link, link_type
def download(url_entry,
             scraper=args.scraper,
             save_uncompressed=args.save_uncompressed,
             memoize=args.scraper_memoize,
             arch_meta=not args.sqlite_meta):
    """Scrape one (uid, url) entry; return (text, meta, fid, uid).

    Returns None when the URL was already downloaded on a previous run, and
    ("", meta, fid, uid) when scraping produced no usable text.
    NOTE(review): the keyword defaults are snapshotted from the module-level
    `args` at import time.
    """
    uid, url = url_entry
    url = url.strip()
    # Stable unique file id: zero-padded uid plus sha256 of the URL.
    fid = "{:07d}-{}".format(uid, sha256(url.encode()).hexdigest())
    data_dir = mkdir(op.join(args.output_dir, "data"))
    text_fp = op.join(data_dir, "{}.txt".format(fid))
    if arch_meta:
        meta_dir = mkdir(op.join(args.output_dir, "meta"))
        meta_fp = op.join(meta_dir, "{}.json".format(fid))

    # already downloaded!
    if op.exists(text_fp):
        return

    # is_good_link, link_type = vet_link(url)
    # if not is_good_link:
    #     return

    # Select the scraper backend chosen on the command line.
    if scraper == "bs4":
        scrape = bs4_scraper
    elif scraper == "newspaper":
        scrape = newspaper_scraper
    elif scraper == "raw":
        scrape = raw_scraper

    text, meta = scrape(url, memoize)
    # Record the full registered domain (subdomain.domain.suffix) in the meta.
    ext = tldextract.extract(url)
    domain = '.'.join([x for x in ext if x])
    meta["domain"] = domain

    if text is None or text.strip() == "":
        return ("", meta, fid, uid)

    if save_uncompressed:
        with open(text_fp, "w") as out:
            out.write(text)
        if arch_meta:
            with open(meta_fp, "w") as out:
                json.dump(meta, out)

    return (text, meta, fid, uid)
def archive_chunk(cid, cdata, out_dir, fmt, arch_meta):
    """Pack one chunk of scraped results into tar archive(s).

    Writes {cid}_data.{fmt} with txt members and, when arch_meta is true,
    {cid}_meta.{fmt} with json members. Returns the number of non-empty
    documents archived.
    """
    mkdir(out_dir)
    texts, metas, fids, uids = zip(*cdata)

    data_tar = op.join(out_dir, "{}_data.{}".format(cid, fmt))
    if arch_meta:
        meta_tar = op.join(out_dir, "{}_meta.{}".format(cid, fmt))
        # NOTE: `texts` is rebound here to a list of payload sequences,
        # one per archive to write.
        tar_fps, texts, exts = [data_tar, meta_tar], [texts, metas], ["txt", "json"]
    else:
        tar_fps, texts, exts = [data_tar], [texts], ["txt"]

    doc_count = 0
    docs_counted = False    # only count documents on the first archive pass
    for tar_fp, txts, ext in zip(tar_fps, texts, exts):
        with tarfile.open(tar_fp, "w:" + fmt) as tar:
            for f, fid in zip(txts, fids):
                if f == "":
                    # skip failed / empty downloads
                    continue
                else:
                    if not docs_counted:
                        doc_count += 1
                    if ext == "json":
                        f = json.dumps(f)
                    # Add the in-memory payload as a tar member without
                    # touching the filesystem.
                    f = f.encode("utf-8")
                    t = tarfile.TarInfo("{}.{}".format(fid, ext))
                    t.size = len(f)
                    tar.addfile(t, io.BytesIO(f))
        docs_counted = True
    return doc_count
def load_state(url_file):
    """Return the resume checkpoint for *url_file*, or 0 when absent/empty."""
    ckpt_path = url_file + '.ckpt'
    if not op.exists(ckpt_path):
        return 0
    with open(ckpt_path) as fp:
        contents = fp.read()
    return int(contents) if contents != '' else 0
def save_state(url_file, cid):
    """Persist the resume checkpoint (index of the last processed element)."""
    with open(url_file + '.ckpt', 'w') as fp:
        fp.write(str(cid))
def sqlite_conn():
    """Open (creating if needed) the metadata.db sqlite database.

    Ensures the `metadata` table and its url/domain lookup indexes exist;
    all DDL is idempotent. Returns the open connection.
    """
    conn = sqlite3.connect('metadata.db')
    ddl_statements = (
        '''
        CREATE TABLE IF NOT EXISTS metadata (
            fid char(64) not null primary key,
            url varchar(2048) not null,
            domain varchar(255) not null,
            word_count int null,
            elapsed int null,
            scraper varchar(255) not null,
            success boolean not null
        );
        ''',
        '''
        CREATE INDEX IF NOT EXISTS ix_meta_url ON metadata(url);
        ''',
        '''
        CREATE INDEX IF NOT EXISTS ix_meta_domain ON metadata(domain);
        ''',
    )
    for statement in ddl_statements:
        conn.execute(statement)
    return conn
# Script entry point: resumable, chunked, multi-process scraping driver.
if __name__ == "__main__":
    if args.sqlite_meta:
        conn = sqlite_conn()
        cur = conn.cursor()

    # Resume support: start_elem is the index of the first unprocessed URL.
    start_elem = load_state(args.url_file)
    start_chnk = start_elem // args.chunk_size

    f_json = open(args.output, "w")

    # URLs we haven't scraped yet (if first run, all URLs in file)
    with open(args.url_file) as fh:
        url_entries = load_urls(fh, args.max_urls)
        pool = mpl.Pool(args.n_procs)

        total = linecount(args.url_file) // args.chunk_size
        print('Total chunks: ', total)
        chunk_iterator = tqdm(enumerate(chunks(url_entries, args.chunk_size, start_elem)), total=total)
        # display already-downloaded chunks on progress bar
        chunk_iterator.update(start_chnk)

        # process one "chunk" of args.chunk_size URLs at a time
        for i, chunk in chunk_iterator:
            cid = start_chnk + i + 1

            tqdm.write("Downloading chunk {}".format(cid))
            t1 = time.time()

            if args.timeout > 0:
                # imap as iterator allows .next() w/ timeout.
                # ordered version doesn't seem to work correctly.
                # for some reason, you CANNOT track j or chunk[j] in the loop,
                # so don't add anything else to the loop below!
                # confusingly, chunksize below is unrelated to our chunk_size
                chunk_iter = pool.imap_unordered(download, chunk, chunksize=1)
                cdata = []
                for j in range(len(chunk)):
                    try:
                        result = chunk_iter.next(timeout=args.timeout)
                        cdata.append(result)
                    except mpl.TimeoutError:
                        tqdm.write(" --- Timeout Error --- ")
            else:
                cdata = list(pool.imap(download, chunk, chunksize=1))

            # NOTE(review): this message is only meaningful in the timeout
            # branch; in the else branch len(chunk) == len(cdata) always.
            tqdm.write("{} / {} downloads timed out".format(len(chunk) - len(cdata), len(chunk)))
            tqdm.write("Chunk time: {} seconds".format(time.time() - t1))

            # write metadata to sqlite
            if args.sqlite_meta:
                # filter(lambda x: x, ...) drops None entries (already-downloaded URLs).
                for text, meta, fid, _ in filter(lambda x: x, cdata):
                    if text:
                        params = (fid, meta["url"], meta["domain"], meta["elapsed"], meta["word_count"],
                                  meta["scraper"], True)
                    else:
                        # failed scrape: record the attempt without counts
                        params = (fid, meta["url"], meta["domain"], None, None, meta["scraper"], False)
                    cur.execute(
                        "insert or ignore into metadata (fid, url, domain, elapsed, word_count, scraper, success) values (?, ?, ?, ?, ?, ?, ?)",
                        params)
                conn.commit()

            # Append this chunk's successful documents to the jsonl output.
            dump_chunk = []
            for text, meta, fid, _ in filter(lambda x: x, cdata):
                if text:
                    line_json = {"text": text, "url": meta["url"]}
                    dump_chunk.append(json.dumps(line_json) + '\n')
            f_json.writelines(dump_chunk)

            # archive and save this chunk to file
            if args.compress:
                tqdm.write("Compressing...")
                t2 = time.time()
                count = archive_chunk(cid, cdata, args.output_dir, args.compress_fmt, not args.sqlite_meta)
                tqdm.write("Archive created in {} seconds".format(time.time() - t2))

            tqdm.write("{} out of {} URLs yielded content\n".format(len(list(filter(lambda x: x and x[0], cdata))),
                                                                    len(chunk)))
            # checkpoint progress after every chunk so a crash can resume here
            save_state(args.url_file, cid * args.chunk_size)
    f_json.close()
    print("Done!")
import hashlib
import multiprocessing as mp
import os
import traceback
import newspaper
import tldextract
import tqdm
from filter import should_exclude
# Module-level setup for the simple downloader script.
# NOTE(review): `hash` shadows the builtin of the same name; dl() below
# depends on this module-level alias, so it is documented rather than renamed.
hash = hashlib.sha256

# Ensure the output directory exists (idempotent across runs).
try:
    os.mkdir('data')
except FileExistsError:
    pass
def dl(url):
    """Fetch one article URL with newspaper and store its text under data/.

    Silently returns on excluded URLs, already-downloaded files, dead links,
    and empty articles.
    """
    url = url.strip()
    if should_exclude(url):
        return

    parts = tldextract.extract(url)
    domain = '.'.join([part for part in parts if part])
    fname = 'data/{}-{}.txt'.format(domain, hash(url.encode()).hexdigest())
    if os.path.isfile(fname):
        # already fetched on a previous run
        return

    # print('Downloading', url)
    try:
        article = newspaper.Article(url, fetch_images=False)
        article.download()
        article.parse()
    except newspaper.article.ArticleException:
        # dead or unfetchable link
        # print('Dead link:', url)
        # traceback.print_exc()
        return

    text = article.text
    if text.strip() == '':
        # print('Empty')
        return

    with open(fname, 'w') as out:
        out.write(text)
# Script entry point: fan the URL list out over a large download pool.
if __name__ == '__main__':
    p = mp.Pool(100)  # num of download threads
    with open('urls.txt') as fh:
        urls = list(fh)
    # tqdm wraps the imap iterator for the progress bar; list() drains the
    # iterator so all downloads actually execute.
    list(tqdm.tqdm(p.imap(dl, urls), total=len(urls)))
    print('Done!')
import re
import tldextract
import tqdm
from utils import linecount
# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
# Conservative URL validator: http(s) scheme, host (domain name or dotted
# IPv4), optional port and path. Compiled once at import time since it is
# applied to every URL in should_exclude().
url_regex = re.compile(
    r'^(?:http)s?://' # http:// or https://
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
    r'(?::\d+)?' # optional port
    r'(?:/?|[/?]\S+)$',
    re.IGNORECASE)

# domains that aren't scraper friendly. do not include subdomains!
# Matched against both the registered base domain and the full domain.
exclude_domains = set([
    # image & video hosting sites
    'imgur.com',
    'redd.it',
    'instagram.com',
    'discord.gg',
    'gfycat.com',
    'giphy.com',
    'reddituploads.com',
    'redditmedia.com',
    'twimg.com',
    'sli.mg',
    'magaimg.net',
    'flickr.com',
    'imgflip.com',
    'youtube.com',
    'youtu.be',
    'youtubedoubler.com',
    'vimeo.com',
    'twitch.tv',
    'streamable.com',
    'bandcamp.com',
    'soundcloud.com',

    # not scraper friendly
    'reddit.com',
    'gyazo.com',
    'github.com',
    'xkcd.com',
    'twitter.com',
    'spotify.com',
    'itunes.apple.com',
    'facebook.com',
    'gunprime.com',
    'strawpoll.me',
    'voyagefusion.com',
    'rollingstone.com',
    'google.com',
    'timeanddate.com',
    'walmart.com',
    'roanoke.com',
    'spotrac.com',

    # original paper excluded wikipedia
    'wikipedia.org',

    # lots of top posts for this one
    'battleforthenet.com',
])

# File extensions indicating binary/media content rather than article text;
# compared case-insensitively against the path portion of the URL.
exclude_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.gifv', '.pdf', '.mp4', '.mp3', '.ogv', '.webm', '.doc',
                      '.docx', '.log', '.csv', '.dat', '.iso', '.bin', '.exe', '.apk', '.jar', '.app', '.ppt', '.pps',
                      '.pptx', '.xml', '.gz', '.xz', '.bz2', '.tgz', '.tar', '.zip', '.wma', '.mov', '.wmv', '.3gp',
                      '.svg', '.rar', '.wav', '.avi', '.7z')
def should_exclude(url):
    """Return True when *url* must be skipped.

    A URL is excluded when it is malformed, belongs to a blocked domain,
    or points at a non-HTML file extension.
    """
    parts = tldextract.extract(url)
    domain = '.'.join([p for p in parts if p])
    basedomain = '.'.join(parts[-2:])

    # Malformed / non-URL input.
    is_malformed = len(url) <= 8 or ' ' in url or re.match(url_regex, url) is None
    # Blocked base domain or full (sub)domain.
    in_blocklist = basedomain in exclude_domains or domain in exclude_domains
    # Case-insensitive match on excluded extensions (query string stripped).
    bad_extension = url.lower().split('?')[0].endswith(exclude_extensions)

    return is_malformed or in_blocklist or bad_extension
# Script entry point: de-duplicate and pre-filter the raw URL list.
if __name__ == '__main__':
    url_file = 'urls.txt'
    filtered_file = 'urls-filtered.txt'
    with open(url_file) as urls, open(filtered_file, 'w') as out:
        url_len = linecount(url_file)
        print("URL file is", url_len, "URLs long.")
        url_set = set()  # set membership de-duplicates while filtering
        for line in tqdm.tqdm(urls, total=url_len):
            if len(line.strip()) == 0:
                continue  # Skip whitespace-only lines
            line = line.strip().split()[0]  # Drop any components following whitespace
            if should_exclude(line):
                continue
            url_set.add(line)
        # NOTE(review): set iteration order is arbitrary, so the output
        # ordering is not stable across runs.
        for line in tqdm.tqdm(url_set):
            out.write(line + '\n')
import datetime
import praw
import psaw
import tqdm
# Query the Pushshift API for all non-self, SFW reddit submissions created
# before 2018 and dump their outbound URLs to urls.txt, one per line.
api = psaw.PushshiftAPI()

# all posts until the end of 2017
end_time = int(datetime.datetime(2018, 1, 1).timestamp())
query = api.search_submissions(before=end_time,
                               filter=['url', 'score'],
                               sort='desc',
                               score='>2',
                               is_self=False,
                               over_18=False)

with tqdm.tqdm() as pbar:
    # download links from submissions
    with open('urls.txt', 'w') as fh:
        for subm in query:
            url = subm.url
            # weird issue with psaw/pushshift that breaks score=">2"
            # (the server-side filter is unreliable, so re-check client-side)
            if subm.score < 3:
                continue
            #print(subm.score)
            # pbar.write(str(datetime.datetime.fromtimestamp(subm.created_utc)))
            pbar.update(1)
            fh.write(url + '\n')
            # flush per URL so progress is durable if the crawl is interrupted
            fh.flush()
# Code taken in large part from https://github.com/jcpeterson/openwebtext
import time
import unicodedata
import bs4
import newspaper
from filter import should_exclude
from htmlmin import minify
from lxml.html.clean import Cleaner
def find_and_filter_tag(tag, soup):
    """tag specific filter logic

    Returns (candidates, word_count) for the given tag. Only "p" is
    implemented: NFKD-normalized paragraphs with at least four
    space-separated words are kept (stripped), and word_count is the total
    word count over the kept paragraphs.
    """
    nodes = soup.find_all(tag)
    normalized = [unicodedata.normalize("NFKD", node.string) for node in nodes if node.string is not None]
    if tag != "p":
        raise NotImplementedError
    kept = [text.strip() for text in normalized if len(text.split(" ")) >= 4]
    count = sum(len(text.split(" ")) for text in kept)
    return (kept, count)
def raw_scraper(url, memoize):
    """Download *url* and return (cleaned_minified_html, metadata).

    Returns (None, minimal_meta) for excluded URLs, failed downloads, or
    articles whose parsed text is empty.
    """
    t1 = time.time()
    if should_exclude(url):
        # heuristic to make downloading faster
        return None, {
            "url": url,
            "scraper": "raw",
        }

    try:
        cleaner = Cleaner()
        cleaner.javascript = True    # strip <script> blocks
        cleaner.style = True         # strip inline styles
        article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
        article.download()
        html = minify(article.html)
        html = cleaner.clean_html(html)
        article.parse()
    except Exception:
        # BUGFIX: narrowed from a bare `except:` so KeyboardInterrupt and
        # SystemExit are no longer swallowed during a long crawl.
        return None, {
            "url": url,
            "scraper": "raw",
        }
    if article.text == "":
        return None, {
            "url": url,
            "scraper": "raw",
        }

    metadata = {"url": url, "elapsed": time.time() - t1, "scraper": "raw"}
    return html, metadata
def newspaper_scraper(url, memoize):
    """Download *url* with newspaper and return (article_text, metadata).

    Returns (None, minimal_meta) for excluded URLs or failed downloads.
    """
    t1 = time.time()
    if should_exclude(url):
        # heuristic to make downloading faster
        return None, {
            "url": url,
            "scraper": "newspaper",
        }

    try:
        article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
        article.download()
        article.parse()
        text = article.text
        count = len(text.split())
    except Exception:
        # BUGFIX: narrowed from a bare `except:` so KeyboardInterrupt and
        # SystemExit are no longer swallowed during a long crawl.
        return None, {
            "url": url,
            "scraper": "newspaper",
        }

    metadata = {
        "url": url,
        "word_count": count,
        "elapsed": time.time() - t1,
        "scraper": "newspaper",
    }
    return text, metadata
def bs4_scraper(url, memoize):
    """Download *url*, extract filtered <p> text via BeautifulSoup.

    Returns (text, metadata); (None, minimal_meta) for excluded URLs or
    failed downloads.
    """
    t1 = time.time()
    if should_exclude(url):
        # heuristic to make downloading faster
        return None, {
            "url": url,
            "scraper": "bs4",
        }

    try:
        article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
        article.download()
        html = article.html
        soup = bs4.BeautifulSoup(html, "lxml")
        text, count = find_and_filter_tag("p", soup)
        # DDB: keep text as a single string for consistency with
        # newspaper_scraper
        text = " ".join(text)
    except Exception:
        # BUGFIX: narrowed from a bare `except:` so KeyboardInterrupt and
        # SystemExit are no longer swallowed during a long crawl.
        return None, {
            "url": url,
            "scraper": "bs4",
        }

    metadata = {
        "url": url,
        "word_count": count,
        "elapsed": time.time() - t1,
        "scraper": "bs4",
    }
    return text, metadata
# Code taken in large part from https://github.com/jcpeterson/openwebtext
import collections
import os
import os.path as op
import re
import tarfile
def extract_month(url_file_name):
    """Extract the 'RS_...YYYY-MM' month token from a reddit-dump file name.

    Raises AttributeError when the basename does not match the pattern.
    """
    basename = op.split(url_file_name)[-1]
    return re.match(r"(RS_.*2\d{3}-\d{2})", basename).group()
def chunks(l, n, s=0):
    """Yield successive n-sized pieces of *l*, skipping the first *s* elements.

    Works for arbitrary iterables (consumed element by element) as well as
    sized sequences (sliced directly). The final piece may be shorter than n.
    """
    # BUGFIX: `collections.Iterable` was removed in Python 3.10; the ABC
    # lives in `collections.abc`.
    if isinstance(l, collections.abc.Iterable):
        chnk = []
        for i, elem in enumerate(l):
            if i < s:
                continue
            chnk.append(elem)
            if len(chnk) == n:
                yield chnk
                chnk = []
        if len(chnk) != 0:
            yield chnk
    else:
        # Fallback for sized, indexable, non-iterable objects.
        for i in range(s, len(l), n):
            yield l[i:i + n]
def extract_archive(archive_fp, outdir="."):
    """Unpack the tar archive at *archive_fp* into *outdir*; return *outdir*.

    NOTE(review): extractall trusts member paths — only use on archives this
    pipeline produced itself.
    """
    with tarfile.open(archive_fp, "r") as archive:
        archive.extractall(outdir)
    return outdir
def mkdir(fp):
    """Create directory *fp* (including parents) if needed; return *fp*.

    Raises FileExistsError when *fp* exists but is not a directory — the old
    try/except FileExistsError version silently accepted that case.
    """
    os.makedirs(fp, exist_ok=True)
    return fp
def linecount(filename):
    """Count newline characters in *filename* using 1 MiB raw reads.

    A final line without a trailing newline is not counted (wc -l semantics).
    """
    lines = 0
    buf_size = 1024 * 1024
    # BUGFIX: the file handle was previously never closed (resource leak).
    with open(filename, 'rb') as f:
        read_f = f.raw.read    # bypass Python-level buffering; we only count bytes
        buf = read_f(buf_size)
        while buf:
            lines += buf.count(b'\n')
            buf = read_f(buf_size)
    return lines
import contextlib
import os
import torch
from dataset.webtext import WebtextDataset
from titans.loss.lm_loss import GPTLMLoss
import colossalai
import colossalai.utils as utils
from colossalai import nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import LinearWarmupLR
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.trainer import Trainer, hooks
from colossalai.utils import is_using_pp
from colossalai.utils.timer import MultiTimer
from colossalai.zero.init_ctx import ZeroInitContext
def calc_local_model_size(model: torch.nn.Module) -> int:
    """Return the number of parameter elements held locally by *model*.

    Under tensor/pipeline parallelism each rank only sees its own shard of
    the parameters, hence "local" size.
    """
    # sum over a generator replaces the manual accumulator loop.
    return sum(p.numel() for p in model.parameters())
def main():
    """Launch distributed GPT training with ColossalAI.

    Builds the Webtext dataloader, constructs the model (optionally under
    ZeRO-3 and/or pipeline parallelism per the ColossalAI config), and runs
    the Trainer for NUM_EPOCHS. Expects the dataset path in the DATA
    environment variable and all hyperparameters in gpc.config.
    """
    parser = colossalai.get_default_parser()
    # --from_torch: launched via torchrun instead of SLURM.
    parser.add_argument('--from_torch', default=False, action='store_true')
    args = parser.parse_args()
    disable_existing_loggers()
    if args.from_torch:
        colossalai.launch_from_torch(config=args.config)
    else:
        # standard launch from SLURM environment variables
        colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
    logger = get_dist_logger()

    logger.info('Build data loader', ranks=[0])
    train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
    train_dataloader = utils.get_dataloader(train_ds,
                                            seed=42,
                                            batch_size=gpc.config.BATCH_SIZE,
                                            pin_memory=True,
                                            shuffle=True,
                                            drop_last=True)

    logger.info('Build model', ranks=[0])
    use_pipeline = is_using_pp()
    use_interleaved = hasattr(gpc.config.model, 'num_chunks')
    num_chunks = getattr(gpc.config.model, 'num_chunks', 1)
    use_zero3 = hasattr(gpc.config, 'zero')

    if not use_pipeline:
        ctx = contextlib.nullcontext()
        if use_zero3:
            # Shard parameters across ranks at construction time (ZeRO-3).
            ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
                                  shard_strategy=gpc.config.zero.model_config.shard_strategy,
                                  shard_param=True)
        with ctx:
            # gpc.config.model holds {'type': ModelClass, **kwargs}.
            model = gpc.config.model.pop('type')(**gpc.config.model)
    else:
        # Pipeline-parallel path: trace the model, then partition it.
        pipelinable = PipelinableContext()
        with pipelinable:
            model = gpc.config.model.pop('type')(**gpc.config.model)

        def mask_function(attention_mask=None):
            # Convert a (batch*seq,) mask into the additive attention-mask
            # shape expected by the attention blocks on this pipeline stage.
            if attention_mask is not None:
                batch_size = gpc.config.BATCH_SIZE // gpc.config.NUM_MICRO_BATCHES
                attention_mask = attention_mask.view(batch_size, -1)
                attention_mask = col_nn.partition_batch(attention_mask)
                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                # 0 -> -10000 (masked), 1 -> 0 (kept)
                attention_mask = (1.0 - attention_mask) * -10000.0
            return attention_mask

        # GPT2_small exec_seq
        # (lyl)TODO: The exec_seq for gpt3 will be added here and to_layer_list should be more friendly to use.
        exec_seq = ['embed', mask_function, 'blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'blocks.4', 'blocks.5', (mask_function, "front"), \
            'blocks.6', 'blocks.7', 'blocks.8', 'blocks.9', 'blocks.10', 'blocks.11', 'norm', 'head']
        pipelinable.to_layer_list(exec_seq)
        ctx = contextlib.nullcontext()
        # (lyl)TODO: Zero context and pipelinable context should be integrated into one context.
        if use_zero3:
            ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
                                  shard_strategy=gpc.config.zero.model_config.shard_strategy,
                                  shard_param=True)
        with ctx:
            model = pipelinable.partition(num_chunks, gpc.pipeline_parallel_size,
                                          gpc.get_local_rank(ParallelMode.PIPELINE))

    if use_zero3:
        # ZeRO context tracked the (sharded) parameter count during init.
        numel = ctx.model_numel_tensor.item()
    else:
        numel = calc_local_model_size(model)

    # Approximate TFLOPs per step; the constant 8 presumably folds the
    # fwd+bwd multiplier — TODO confirm against the throughput hook's units.
    tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \
        * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)

    # Use the configured loss function if present, else default GPT LM loss.
    criterion = getattr(gpc.config, 'loss_fn', None)
    if criterion is not None:
        criterion = criterion.type()
    else:
        criterion = GPTLMLoss()

    logger.info('Build optimizer', ranks=[0])
    optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer)
    # NOTE(review): total_steps is set to NUM_EPOCHS (scheduler stepped per
    # epoch via by_epoch=True below).
    lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)

    engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
                                                                      optimizer,
                                                                      criterion,
                                                                      train_dataloader=train_dataloader,
                                                                      lr_scheduler=lr_scheduler)

    global_batch_size = gpc.config.BATCH_SIZE * \
        gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
    logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])

    # NOTE(review): 'timier' is a typo for 'timer'; local name only, harmless.
    timier = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timier)

    hook_list = [
        hooks.LossHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
        hooks.LogMetricByEpochHook(logger),
        hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
        hooks.LogMetricByStepHook(),
        hooks.LogMemoryByEpochHook(logger),
    ]

    trainer.fit(train_dataloader=train_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                test_interval=1,
                hooks=hook_list,
                display_progress=True,
                return_output_label=False)


if __name__ == '__main__':
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment