"vscode:/vscode.git/clone" did not exist on "853d4eb9cf51fced975a428de15428fb4860a449"
Commit 7a5768ac authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main_dedup' into 'main'

Training data and task deduplication

See merge request ADLR/megatron-lm!252
parents f32a638d 0fa728ac
......@@ -26,7 +26,8 @@ python blacklist_urls.py <path to the dowloaded deduplicated URLs> <filename for
```
python cleanup_dataset.py <input data file> <output cleaned data filename>
```
2. Using LSH, find possible duplicates and store then in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for OpenWebText dataset. The code supports saving and loading fingerprints for recurrent deduplications.
Additional cleanup (e.g. remove documents less than 512 characters or dataset specific cleaning like stories, realnews datasets) can be done using `cleanup_fix_dataset.py`. More details can be found by running `python cleanup_fix_dataset.py --help`.
2. Using LSH, find possible duplicates and store then in a file for later processing. The code supports saving and loading fingerprints for recurrent deduplications, and is also multithreaded for faster processing. More details are can be found by `python find_duplicate.py --help`.
```
python find_duplicates.py --inputs <pairlist list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
```
......@@ -46,10 +47,13 @@ shuf <cleaned deduped data file> -o train_data.json
# Deduplicating ngrams
To deduplicate the downstream tasks from the training dataset, we run the following command.
To deduplicate the downstream tasks (e.g. lambada, squad) from the training dataset, we run the following command.
```
python filter_ngrams.py <down stream task dataset> <training dataset to deduplicate> <output training dataset>
python filter_ngrams.py --tasks <name of he task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
```
We use 13-grams by default for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from the both side of the 13-gram. We also remove any splitted document with less than 200 characters or if a document got splitted more than 10 times. These parameters can be changed using corresponding arguments.
We use 13-grams for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from the both side of the 13-gram. We also remove any splitted document with less than 200 characters or if a document got splitted more than 10 times.
Only for the lambada task, we need to provide the path, `--lambada-path <path of the lambada test data>`.
Several other features (e.g. save and load dictionary) have been added, look at `python filter_ngrams.py --help` for details.
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import time
"""
This code adds id to each json object in a json file. User can add prefix
to the ids.
"""
if __name__ == '__main__':
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--input-file', type=str, default=None, help='Input'\
' json file where id needs to be added')
parser.add_argument('--output-file', type=str, default=None, help=\
'Output file name with id')
parser.add_argument('--id-prefix', type=str, default=None, help=\
'Id prefix')
parser.add_argument('--log-interval', type=int, default=100,
help='Log interval')
args = parser.parse_args()
print('Adding ids to dataset ...')
f_input = open(args.input_file, 'r', encoding='utf-8')
f_output = open(args.output_file, 'wb')
unique_ids = 1
start_time = time.time()
for row in f_input:
each_row = json.loads(row)
adlr_id_string = args.id_prefix + '-{:010d}'.format(int(unique_ids))
each_row['adlr_id'] = adlr_id_string
myjson = json.dumps(each_row, ensure_ascii=False)
f_output.write(myjson.encode('utf-8'))
f_output.write('\n'.encode('utf-8'))
if unique_ids % args.log_interval == 0:
print(' processed {:9d} documents in {:.2f} seconds ...'.format( \
unique_ids, time.time() - start_time), flush=True)
unique_ids += 1
# Close the file.
f_input.close()
f_output.close()
print('done :-)', flush=True)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Filter and clean documents:
Capable to clean docs with less than 512 characters, less than
256 characters and contains javascript, fix text and dataset specific
cleaning like stories and realnews datasets.
Program arguments have the details.
"""
import argparse
from functools import partial
import glob
import ftfy
import json
from langdetect import detect
import multiprocessing
import os
from pathlib import Path
import re
import time
def process_doc(json_line, args):
# Read the line.
document = json.loads(json_line)
text = document['text']
output = {'remove_512': False, 'remove_256_javascript': False, \
'remove_512_non_english': False, 'ftfy_fix_text': False, \
'general_cleaning': False}
try:
# Reomove all docs with less than 512 characters
if "remove_512" in args.tasks:
if len(text) < 512:
output['remove_512'] = True
return output, text, document, True
# Remove docs if less than 256 character length and contains Javascript
if "remove_256_javascript" in args.tasks:
if len(text) < 256 and 'javascript' in text.lower():
output['remove_256_javascript'] = True
return output, text, document, True
# Remove docs < 512 and nonenglish
if "remove_512_non_english" in args.tasks:
if len(text) < 512 and detect(text) != 'en':
output['remove_512_non_english'] = True
return output, text, document, True
# Fix the text using ftfy, don't remove the text, hence return False
if "ftfy_fix_text" in args.tasks:
fixed_text = ftfy.fix_text(text)
output['ftfy_fix_text'] = True
return output, fixed_text, document, False
# Cleaning extra spaces and newlines
if "general_cleaning" in args.tasks:
cleaned_text = re.sub(r" +|\b\n+ |\b\n+", " ", text)
#cleaned_text = re.sub(r"\n\n+", "\n\n", text) # used this for Gutenberg dataset
#cleaned_text = re.sub(r"\n", "\n\n", text) # Used this for realnews
# stories datasets
#cleaned_text = re.sub(r" \'", "'", text)
#cleaned_text = re.sub(r" \!", "!", cleaned_text)
#cleaned_text = re.sub(r" \.", ".", cleaned_text)
#cleaned_text = re.sub(r" \?", "?", cleaned_text)
#cleaned_text = re.sub(r" - ", "-", cleaned_text)
##cleaned_text = re.sub(r"\" ", "\"", cleaned_text)
#cleaned_text = re.sub(r" @ ", "@", cleaned_text)
output['general_cleaning'] = True
return output, cleaned_text, document, False
except Exception as e:
print('Error: *************************\n{}\ntext: {}'.format(e, \
text), flush=True)
return output, text, document, True
# don't remove
return output, text, document, False
def process_set(args, input_file, output_f_cleaned, output_f_filtered):
print(' > working on {} ...'.format(input_file), flush=True)
num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
= num_ftfy_fix_text = num_general_cleaning = 0
# Output file and counters.
output_cleaned = open(output_f_cleaned, 'wb')
output_filtered = open(output_f_filtered, 'wb')
start_time = time.time()
# Setup multi-processing.
num_workers = 40
fin = open(input_file, 'r', encoding='utf-8')
pool = multiprocessing.Pool(num_workers)
process_doc_partial = partial(process_doc, args=args)
processed_docs = pool.imap(process_doc_partial, fin, 500)
# Process documents.
for output, text, document, to_filter in processed_docs:
num_docs += 1
num_remove_512 += 1 if output['remove_512'] else 0
num_remove_java += 1 if output['remove_256_javascript'] else 0
num_remove_512_non_english += 1 if output['remove_512_non_english'] \
else 0
num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
num_general_cleaning += 1 if output['general_cleaning'] else 0
document['text'] = text
myjson = json.dumps(document, ensure_ascii=False)
if to_filter:
output_filtered.write(myjson.encode('utf-8'))
output_filtered.write('\n'.encode('utf-8'))
else:
output_cleaned.write(myjson.encode('utf-8'))
output_cleaned.write('\n'.encode('utf-8'))
if num_docs % args.log_interval == 0:
print(' processed {:9d} documents in {:.2f} seconds ...'.format(
num_docs, time.time() - start_time), flush=True)
# Close the file.
output_cleaned.close()
output_filtered.close()
fin.close()
# Print stats.
print(' >> total docs: {} remove_512 {} remove_256_javascript {} '\
'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.\
format(num_docs, num_remove_512, num_remove_java,\
num_remove_512_non_english, num_ftfy_fix_text, \
num_general_cleaning), flush=True)
if __name__ == '__main__':
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--input-files', nargs = '*', required=True, default=\
None, help = 'Input json files that needs to be'\
' cleaned')
parser.add_argument('--tasks', nargs = '*', required=True, default=None,\
help = 'Tasks to perform on the input files, ' \
'such as remove_512, remove_256_javascript, ' \
'remove_512_non_english, ftfy_fix_text, and ' \
'general_cleaning. 256 or 512 means the number' \
' of characters.')
parser.add_argument('--output-path', type=str, default=None,
help='Directory where the output should go')
parser.add_argument('--log-interval', type=int, default=100,
help='Log interval')
args = parser.parse_args()
print('cleanup dataset ...')
for input_file in args.input_files:
input_filename, input_filename_ext = os.path.splitext(Path(input_file)\
.name)
output_f_cleaned = os.path.join(args.output_path, input_filename + \
"_cleaned" + input_filename_ext)
output_f_filtered = os.path.join(args.output_path, input_filename + \
"_filtered" + input_filename_ext)
process_set(args, input_file, output_f_cleaned, output_f_filtered)
print('done :-)', flush=True)
......@@ -19,10 +19,12 @@ All split documents with less than 200 characters got filtered. Any document
with more than 10 splits got filtered as well.
"""
import argparse
from functools import partial
import json
import multiprocessing
import nltk
import pickle
import re
import string
import sys
......@@ -36,138 +38,455 @@ def get_words(text):
positions.append(match.start())
return words, positions
def free_ngram(line, ngrams, ngram_size, filter_text_len,
splits_count, split_window_each_size):
# splits the text
def split_text(text, start_position, remove_char_each_side, seq):
# first part of the text
punctuations = ".!?"
pos = start_position - remove_char_each_side
text_first = ""
while pos > 0 and not text[pos] in punctuations:
pos -= 1
if pos > 0:
text_first = text[0:pos+1]
# add length of seq and remove_char_each_side
pos = start_position + len(seq) + remove_char_each_side
# last part of the text
text_second = ""
while pos < len(text) and not text[pos] in punctuations:
pos += 1
if pos + 1 < len(text):
text_second = text[pos+1:len(text)]
return text_first, text_second
def check_and_clean_text(args, words, ngrams, text, start_position, \
text_buf_ngram_free, text_buf, local_ngram):
seq = " ".join(words)
if seq in ngrams:
print(" [matched]: {}".format(seq), flush=True)
if args.get_ngram_freq_only:
# increase freq of this seq and then only consider the later part
# of the text for further processing
if seq in local_ngram:
local_ngram[seq] += 1
else:
local_ngram[seq] = 1
#print(" [increased]: {} {}".format(seq, ngrams[seq]), flush=True)
if (start_position + len(seq) + 1) < len(text):
text_buf.append(text[start_position + len(seq) + 1:len(text)])
return False
# split the text
text_first, text_second = split_text(text, start_position, \
args.remove_char_each_side, seq)
# first part of ngrams free
if len(text_first) > args.filter_text_char_len:
text_buf_ngram_free.append(text_first)
# add second part for further processing
if len(text_second) > args.filter_text_char_len:
text_buf.append(text_second)
return False # not ngram free
# ngram free
return True
def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
# remove all the ngrams
try:
myjson = json.loads(line)
text_buf = [myjson['text']]
text_buf = [myjson[key]]
except Exception as e:
print("Error: {}".format(e), flush=True)
text_buf = []
text_buf_ngram_free = []
local_ngram = {}
while len(text_buf) > 0:
# get the first one from the buffer
text = text_buf.pop(0)
words, positions = get_words(text)
not_ngram_free = True
punctuations = ".!?"
# find n-grams
for i in range(len(words) - ngram_size + 1):
seq = " ".join(words[i:i+ngram_size])
if seq in ngrams:
# splits the text
# first part of the text
pos = positions[i] - split_window_each_size
text_first = ""
while pos > 0 and not text[pos] in punctuations:
pos -= 1
if pos > 0:
text_first = text[0:pos+1]
pos = positions[i] + split_window_each_size
# last part of the text
text_second = ""
while pos < len(text) and not text[pos] in punctuations:
pos += 1
if pos + 1 < len(text):
text_second = text[pos+1:len(text)]
# first part of ngrams free
if len(text_first) > filter_text_len:
text_buf_ngram_free.append(text_first)
# add second part for further processing
if len(text_second) > filter_text_len:
text_buf.append(text_second)
not_ngram_free = False
ngram_free = True
# find each max n-grams and check dictionary
for i in range(len(words) - args.max_ngram_size + 1):
check_ngram_free = check_and_clean_text(args, words[i:\
i+args.max_ngram_size], ngrams, text, positions[i], \
text_buf_ngram_free, text_buf, local_ngram)
# the seq is ngram free? if yes, break
if not check_ngram_free:
ngram_free = False
break
# if max ngrams doesn't match, check if any other lower n-grams
# within max ngram macthes
for ngram_len, _ in ngrams_freq_sorted:
check_ngram_free = check_and_clean_text(args, words[i:\
i+ngram_len], ngrams, text, positions[i], \
text_buf_ngram_free, text_buf, local_ngram)
# same check as above
if not check_ngram_free:
ngram_free = False
break
# check break from lower than max ngram loop above
if not ngram_free:
break
# text are ngram free
if not_ngram_free:
# for the last max n-gram, check all the lower ngrams in it
if ngram_free and len(words) - args.max_ngram_size > 0:
# get the last words of the lax max ngram
last_seq_words = words[(len(words)-args.max_ngram_size):len(words)]
last_seq_start_position = len(words) - args.max_ngram_size
# check all n-grams lower than the max
for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):
# ignore the max ngram as has been considered already
if ngram_len == args.max_ngram_size:
continue
# find each ngram of ngram_len in max n-grams and check
for i in range(len(last_seq_words) - ngram_len + 1):
check_ngram_free = check_and_clean_text(args, \
last_seq_words[i:i+ngram_len], ngrams, text,\
positions[last_seq_start_position+i], \
text_buf_ngram_free, text_buf, local_ngram)
if not check_ngram_free:
ngram_free = False
break
if not ngram_free:
break
# texts are ngram free
if ngram_free and not args.get_ngram_freq_only:
text_buf_ngram_free.append(text)
return text_buf_ngram_free
# check if the text has only been trimmed
trimmed = 0
if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
len(text_buf_ngram_free[0]) < len(myjson[key]):
trimmed = 1
return text_buf_ngram_free, trimmed, myjson, local_ngram
if __name__ == '__main__':
# insert word sequence into dictionary
def insert_dict(words, ngrams, pos):
seq = " ".join(words)
if seq not in ngrams:
ngrams[seq] = 0
#ngrams[seq] = pos
print('finding possible duplicate content ...')
main_file = sys.argv[1] # lambada file
dedup_file = sys.argv[2] # Book corpus
output_file = sys.argv[3] #Filtered book corpus
ngrams = {}
id_prefix = "lambada"
# insert each ngram from text into the ngrams dictionary
def compute_ngrams_insert_dict(args, text, ngrams):
words, positions = get_words(text)
if len(words) < args.min_ngram_size:
return
# we use 13-grams, any text less than 200 characters got removed
# any text splitted more than 10 got removed as well
ngram_size = 13
filter_text_len = 200
splits_count = 10
split_window_each_size = 200
if len(words) < args.max_ngram_size:
insert_dict(words, ngrams, positions[0])
for i in range(len(words) - args.max_ngram_size+1):
insert_dict(words[i:i+args.max_ngram_size], ngrams, positions[i])
print('Reading file {} and computing ngrams'.format(main_file))
with open(main_file, 'r') as f:
# Build ngrams for the lambada dataset
def process_task_lambda(args, task_file, ngrams):
print(' reading from {} and computing ngrams'.format(task_file))
with open(task_file, 'r') as f:
for line in f:
try:
myjson = json.loads(line)
words, positions = get_words(myjson['text'])
for i in range(len(words) - ngram_size+1):
seq = " ".join(words[i:i+ngram_size])
if seq not in ngrams:
ngrams[seq] = positions[i]
text = myjson['text']
compute_ngrams_insert_dict(args, text, ngrams)
except Exception as e:
print('Error:', e)
print("ngrams size {}".format(len(ngrams)))
print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
print('Reading file {} and deduping n-grams'.format(dedup_file))
counter = 0
# Build ngrams for the dataset of the given task
def process_task(args, task_name, ngrams):
print(' reading from {} and computing ngrams'.format('import datasets'))
print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
# using validation/test data from datasets
from datasets import load_dataset
entities_in_ngrams = len(ngrams)
# load the dataset
if task_name == 'squad':
dataset = load_dataset('squad_v2', split='validation')
elif task_name == 'natural_questions':
dataset = load_dataset('natural_questions', split='validation')
elif task_name == 'triviaqa':
dataset = load_dataset('trivia_qa', 'unfiltered', split='test')
elif task_name == 'webqa':
dataset = load_dataset('web_questions', split='test')
elif task_name == 'race':
dataset = load_dataset('race', 'all', split='test')
elif task_name == 'drop':
dataset = load_dataset('drop', split='validation')
elif task_name == 'coqa':
dataset = load_dataset('coqa', split='validation')
elif task_name == 'piqa':
dataset = load_dataset('piqa', split='test')
else:
print("Invalid task name: {}".format(task_name), flush=True)
return
# read the dataset and add to ngrams
for line in dataset:
try:
if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']:
text = line['question']
compute_ngrams_insert_dict(args, text, ngrams)
elif task_name == 'natural_questions':
text = line['question']['text']
compute_ngrams_insert_dict(args, text, ngrams)
elif task_name == 'coqa':
all_questions = line['questions']
for question in all_questions:
compute_ngrams_insert_dict(args, question, ngrams)
elif task_name == 'piqa':
text = line['goal']
compute_ngrams_insert_dict(args, text, ngrams)
except Exception as e:
print('Error:', e)
print(" After task {} entities in ngrams {}, added {}".format(task_name, \
len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)
def compute_tasks_ngrams(args, ngrams):
start_time = time.time()
out_f = open(output_file, 'wb')
splitted, ignored, split_mt_thld = 0, 0, 0
for _, task_name in enumerate(args.tasks):
print('Task: {}'.format(task_name), flush=True)
if task_name == 'lambada':
assert args.lambada_path is not None
process_task_lambda(args, args.lambada_path, ngrams)
else:
process_task(args, task_name, ngrams)
print(" Taken time to compute ngrams {:.2f}".format(time.time() - \
start_time), flush=True)
# Setup multi-processing.
num_workers = 40
def compute_ngram_freq_sorted(args, ngrams):
ngrams_freq = {}
for ngram_key in ngrams.keys():
length = len(ngram_key.split())
ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
ngrams_freq else 1
ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])
print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
ngrams_freq_sorted) -1 ][0]), flush=True)
return ngrams_freq_sorted
def get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
dedup_file, dedup_key, ngrams_freq_sorted):
start_time = time.time()
# get the ngrams frequency
args.get_ngram_freq_only = True
# Open the large file to process in parallel
num_workers = args.num_threads
pool = multiprocessing.Pool(num_workers)
fin = open(dedup_file, 'r', encoding='utf-8')
free_ngram_abt_partial=partial(free_ngram, args=args, key=dedup_key, \
ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)
counter = 0
for _, _, _, local_ngram in free_ngrams_abt:
counter += 1
if counter % 1000 == 0:
print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
for local_key in local_ngram:
if local_key in ngrams:
ngrams[local_key] += 1
local_ngram = {}
print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
start_time), flush=True)
pool.close()
pool.join()
start_time = time.time()
counter_threshold = 0
# Get ngram below theadhold
for local_key, local_val in ngrams.items():
if ngrams[local_key] < args.key_threshold:
print(" [threshold] {} {}".format(local_key, local_val), flush=True)
counter_threshold += 1
ngrams_below_threshold[local_key] = 1
print(' Ngrams below threshold {}'.format(counter_threshold), flush=True)
fin.close()
def clean_ngrams_below_threshold(args, ngrams_below_threshold, dedup_file, \
dedup_key):
start_time = time.time()
# Now actually filter the dataset
args.get_ngram_freq_only = False
#id_prefix = '-'.join(args.tasks[::2])
id_prefix = '-'.join(args.tasks[::1])
# get the range of the size of the ngrams
ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_below_threshold)
# Open the large file to process in parallel
counter = splitted = ignored = split_mt_thld = trimmed_count = 0
num_workers = args.num_threads
pool = multiprocessing.Pool(num_workers)
free_ngram_x=partial(free_ngram, ngrams=ngrams, ngram_size=ngram_size,
filter_text_len=filter_text_len, splits_count=splits_count,
split_window_each_size=split_window_each_size)
free_ngrams = pool.imap(free_ngram_x, fin, 25)
fin = open(dedup_file, 'r', encoding='utf-8')
free_ngram_clean_partial=partial(free_ngram, args=args, key=dedup_key, \
ngrams=ngrams_below_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)
out_f = open(args.output, 'wb')
for text_buf_ngram_free in free_ngrams:
for text_buf_ngram_free, trimmed, myjson, _ in free_ngrams_clean:
counter += 1
try:
trimmed_count += trimmed
if len(text_buf_ngram_free) > 1:
splitted += (len(text_buf_ngram_free) - 1)
splitted += 1
if len(text_buf_ngram_free) == 0:
ignored += 1
# more than 10 splits ignored
if len(text_buf_ngram_free) > splits_count:
if len(text_buf_ngram_free) > args.splits_count:
text_buf_ngram_free = []
split_mt_thld += 1
for i in range(len(text_buf_ngram_free)):
split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \
+ '-{:010d}'.format(int(i))
outjson = json.dumps({"text":text_buf_ngram_free[i],
id_prefix+"_split_id":split_id_string},
ensure_ascii=False)
out_f.write(outjson.encode('utf-8'))
out_f.write('\n'.encode('utf-8'))
if args.output is not None:
if "split_id" in myjson:
use_prefix = myjson["split_id"] + "-"
else:
use_prefix = ""
for i in range(len(text_buf_ngram_free)):
split_id_string = id_prefix + '-{:010d}'.format(int(\
counter)) + '-{:04d}'.format(int(i))
myjson[dedup_key] = text_buf_ngram_free[i]
myjson["split_id"] = use_prefix + split_id_string
outjson = json.dumps(myjson, ensure_ascii=False)
#outjson = json.dumps({"text":text_buf_ngram_free[i],
# id_prefix+"_split_id":split_id_string},
# ensure_ascii=False)
out_f.write(outjson.encode('utf-8'))
out_f.write('\n'.encode('utf-8'))
if counter % 1000 == 0:
print(' [search]> processed {} documents in {:.2f} seconds ...'.
print(' [final]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
except Exception as e:
print('Error:', e)
print("Deduped file written to: {}".format(output_file), flush=True)
print("Total docs {} splitted {} ignored {} docs with many splits {}".\
format(counter, splitted, ignored, split_mt_thld), flush=True)
print(' [final]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
print(' Total docs {} splitted {} ignored {} splits > theshold {} trimmed'\
' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count)\
, flush=True)
pool.close()
pool.join()
out_f.close()
fin.close()
if __name__ == '__main__':
# we use 13-grams, any text less than 200 characters got removed
# any text splitted more than 10 got removed as well
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
help = 'Tasks to use for deduplication: currently '
' suuport [lambada, squad, natural_questions,'
' triviaqa, webqa, race, drop, coqa, and piqa]')
parser.add_argument('--lambada-path', type=str, default=None,
help='Only Lambada task needs the path')
parser.add_argument('--dedup-dataset', nargs = '*', default=None,
help='Dataset to deduplicate with the key to use'
' e.g. cc.json text')
parser.add_argument('--output', type=str, default=None,
help='Output file name to save dedup dataset')
parser.add_argument('--num-threads', type=int, default=40,
help='Number of threads to use')
# Default dedup values
parser.add_argument('--max-ngram-size', type=int, default=13,
help='Maximum size of ngram to use.')
parser.add_argument('--min-ngram-size', type=int, default=8,
help='Minimum size of ngram to use.')
parser.add_argument('--filter-text-char-len', type=int, default=200,
help='Remove any text below this length.')
parser.add_argument('--key-threshold', type=int, default=10,
help='Number of keys to consider as threshold')
parser.add_argument('--save-dictionary', type=str, default=None,
help='Save the dictionary')
parser.add_argument('--load-dictionary', type=str, default=None,
help='Load the dictionary')
parser.add_argument('--splits-count', type=int, default=10,
help='Remove any documents more than this many splits')
parser.add_argument('--remove-char-each-side', type=int, default=200,
help='Maximum size of ngram to use.')
args = parser.parse_args()
assert len(args.dedup_dataset) == 2
dedup_file = args.dedup_dataset[0]
dedup_key = args.dedup_dataset[1]
# Setup multi-processing
num_workers = args.num_threads
if args.load_dictionary is None:
# Build ngrams
ngrams = {}
compute_tasks_ngrams(args, ngrams)
# get the range of the size of the ngrams
ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)
# get ngram freq from large file in parallel
# get ngrams below threshold
ngrams_below_threshold = {}
get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
dedup_file, dedup_key, ngrams_freq_sorted)
# save the dictionary if needed
if args.save_dictionary is not None:
with open(args.save_dictionary, 'wb') as save_dict_handle:
pickle.dump(ngrams_below_threshold, save_dict_handle)
else:
with open(args.load_dictionary, 'rb') as load_dict_handle:
ngrams_below_threshold = pickle.load(load_dict_handle)
# filter the large file
if args.output is not None:
clean_ngrams_below_threshold(args, ngrams_below_threshold, \
dedup_file, dedup_key)
print('done :-)')
......@@ -14,13 +14,16 @@
# limitations under the License.
import argparse
from functools import partial
import itertools
import json
from lsh import cache, minhash
import multiprocessing
import numpy as np
import time
import pickle
import sys
import os
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
......@@ -31,11 +34,158 @@ def shingles(text, char_ngram=5):
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b):
def jaccard(set_a, set_b, args):
if len(set_a) < 1 or len(set_b) < 1:
return 0.0
intersection = set_a & set_b
union = set_a | set_b
return len(intersection) / len(union)
if args.jaccard == 'min':
return len(intersection) / min(len(set_a), len(set_b))
elif args.jaccard == 'max':
return len(intersection) / max(len(set_a), len(set_b))
else:
return len(intersection) / len(union)
def compute_fingerprint(line, key):
try:
myjson = json.loads(line)
url = myjson[key]
text = myjson['text']
fingerprint = hasher.fingerprint(text)
except Exception as e:
print('Error:', e)
return None, None, None, False
return url, text, fingerprint, True
def url_pairs_to_remove(args, bucket_urls, url_doc):
remove_urls_list = []
deduped_local, counter_local = 0, 0
iteration = 0
while len(bucket_urls) > 1:
if args.heuristic_iter != -1 and \
iteration == args.heuristic_iter:
break
items = list(bucket_urls)
remove_urls = []
main_url = items[np.random.randint(0, len(items))]
main_dhingles = shingles(url_doc[main_url])
for i in range(0, len(items)):
counter_local += 1
other_url = items[i]
if other_url == main_url:
continue
other_shingles = shingles(url_doc[other_url])
try:
jaccard_sim = jaccard(main_dhingles, other_shingles, args)
except Exception as e:
print('Error:', e)
jaccard_sim = 0.0
if jaccard_sim > 0.5:
remove_urls.append({other_url: jaccard_sim})
deduped_local += 1
bucket_urls.remove(other_url)
bucket_urls.remove(main_url)
if len(remove_urls) > 0:
remove_urls_list.append({main_url: remove_urls})
iteration += 1
return remove_urls_list, deduped_local, counter_local
def write_remove_urls_list(remove_urls_list, f_out):
if len(remove_urls_list) > 0:
for each_url_remove in remove_urls_list:
myjson = json.dumps(each_url_remove, ensure_ascii=False)
f_out.write(myjson.encode('utf-8'))
f_out.write('\n'.encode('utf-8'))
def compute_jaccard(each_bin, num_bins, start_time_local):
remove_urls_list = []
deduped_local, counter_local, bucket_local = 0, 0, 0
for bucket_id in each_bin:
bucket_local += 1
if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
print("Counter {}, progress {:.2f} time {:.2f}".\
format(bucket_local, float(bucket_local)/float(len(each_bin)),\
time.time() - start_time_local), flush=True)
if len(each_bin[bucket_id]) <= 1:
continue
bucket_urls = each_bin[bucket_id].copy()
remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
url_pairs_to_remove(args, bucket_urls, url_doc)
deduped_local += deduped_local_sub
counter_local += counter_local_sub
if len(remove_urls_list_sub) > 0:
remove_urls_list.extend(remove_urls_list_sub)
return remove_urls_list, deduped_local, counter_local
def find_pair_urls_parallel(args, lshcache, url_doc):
start_time = time.time()
f_out = open(args.output, 'wb')
deduped, counter = 0, 0
# compute jaccards of buckets in bin in parallel (parallelism
# limited to # of bins)
num_bins = len(lshcache.bins)
pool = multiprocessing.Pool(num_bins)
compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
start_time_local=start_time)
# don't need to pass args and url_doc as they are already shared
compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)
print("multiprocessing init took {:.2f}".format(time.time() - start_time),\
flush=True)
for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
deduped += deduped_local
counter += counter_local
write_remove_urls_list(remove_urls_list, f_out)
print(' [write]> processed {} documents in {:.2f} '
'seoncds and deduped {} documents ...'.format(counter, time.time()\
- start_time, deduped), flush=True)
pool.close()
pool.join()
f_out.close()
print(' Taken time for jaccard similariries {:.2f} seconds'.format(\
time.time() - start_time), flush=True)
def find_pair_urls_sequential(args, lshcache, url_doc):
start_time = time.time()
f_out = open(args.output, 'wb')
deduped, counter = 0, 0
for b in lshcache.bins:
for bucket_id in b:
if len(b[bucket_id]) <= 1:
continue
bucket_urls = b[bucket_id].copy()
remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
url_pairs_to_remove(args, bucket_urls, url_doc)
deduped += deduped_local_sub
counter += counter_local_sub
write_remove_urls_list(remove_urls_list_sub, f_out)
if counter % 10000 == 0:
print(' [write]> processed {} documents in {:.2f} '
'seoncds and deduped {} documents ...'.
format(counter, time.time() - start_time,
deduped), flush=True)
f_out.close()
print(' [write]> processed {} documents in {:.2f} '
'seoncds and deduped {} documents ...'.
format(counter, time.time() - start_time,
deduped), flush=True)
if __name__ == '__main__':
......@@ -55,17 +205,30 @@ if __name__ == '__main__':
parser.add_argument('--output', type=str, default=None,
help='Output file name that consists of all ids'
' with matching similarities')
parser.add_argument('--jaccard', type=str, default='union',
choices=['union', 'min', 'max'], help='Jaccard'\
' similarity computation')
parser.add_argument('--heuristic-iter', type=int, default=1,
help='Number of iterations to run the heuristics'
': use -1 for exact')
parser.add_argument('--num-bands', type=int, default=10,
help='Number of bands to use in cache')
parser.add_argument('--num-seeds', type=int, default=100,
help='Number of seeds to use for minhash. Note that'
' this value should be divisible by num-bands')
parser.add_argument('--jaccard-parallel', action='store_true',
help='Use this to process large number of documents.')
args = parser.parse_args()
print('finding possible duplicate content ...')
# set seed and get an array of seeds of 100 integers
np.random.seed(args.seed)
seeds = np.random.randint(0, 1e6, size=100)
seeds = np.random.randint(0, 1e6, size=args.num_seeds)
# initialize minhash and lsh cache
hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
lshcache = cache.Cache(bands=10, hasher=hasher)
lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)
url_doc = {}
......@@ -91,31 +254,36 @@ if __name__ == '__main__':
counter = 0
start_time = time.time()
print("Computing fingerprints", flush=True)
# compute finger prints of the inputs if any
# input file and the key to use as id
if args.inputs is not None:
print("Computing fingerprints", flush=True)
assert len(args.inputs) % 2 == 0
for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
print(' document processing {} with key {}'.format(input_file, key),
flush=True)
# compute fingerprints in parallel
num_workers = 40
pool = multiprocessing.Pool(num_workers)
fin = open(input_file, 'r', encoding='utf-8')
compute_fingerprint_partial = partial(compute_fingerprint, key=key)
compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
fin, 512)
# traverse all the texts and add fingerprints
with open(input_file, 'r') as f_input:
for line in f_input:
try:
myjson = json.loads(line)
url = myjson[key]
text = myjson['text']
counter += 1
url_doc[url] = text
lshcache.add_fingerprint(hasher.fingerprint(text), url)
except Exception as e:
print('Error:', e)
if counter % 10000 == 0:
print(' [read]> processed {} documents in {:.2f} '
'seconds ...'.format(counter, time.time() - \
start_time), flush=True)
for url, text, fingerprint, flag in compute_fingerprint_iter:
counter += 1
if flag:
url_doc[url] = text
lshcache.add_fingerprint(fingerprint, url)
if counter % 10000 == 0:
print(' [read]> processed {} documents in {:.2f} '
'seconds ...'.format(counter, time.time() - \
start_time), flush=True)
fin.close()
pool.close()
pool.join()
# Save the fingerprints if needed
if args.save_fingerprints is not None:
......@@ -125,40 +293,13 @@ if __name__ == '__main__':
pickle.dump(lshcache, f_save)
pickle.dump(url_doc, f_save)
counter = 0
start_time = time.time()
deduped = 0
# compute jaccard index of the input texts and write to file if needed
if args.output is not None:
f_out = open(args.output, 'wb')
for b in lshcache.bins:
for bucket_id in b:
if len(b[bucket_id]) > 1:
items = list(b[bucket_id])
main_url = items[0]
main_dhingles = shingles(url_doc[main_url])
remove_urls = []
for i in range(1, len(items)):
counter += 1
other_url= items[i]
other_shingles = shingles(url_doc[other_url])
try:
jaccard_sim = jaccard(main_dhingles, other_shingles)
except Exception as e:
print('Error:', e)
if jaccard_sim > 0.5:
remove_urls.append({other_url: jaccard_sim})
deduped += 1
if counter % 10000 == 0:
print(' [write]> processed {} documents in {:.2f} '
'seoncds and deduped {} documents ...'.
format(counter, time.time() - start_time,
deduped), flush=True)
if len(remove_urls) > 0:
myjson = json.dumps({main_url: remove_urls},
ensure_ascii=False)
f_out.write(myjson.encode('utf-8'))
f_out.write('\n'.encode('utf-8'))
f_out.close()
print("Compute jaccard similarity", flush=True)
if args.jaccard_parallel:
find_pair_urls_parallel(args, lshcache, url_doc)
else:
find_pair_urls_sequential(args, lshcache, url_doc)
print('done :-)')
......@@ -18,10 +18,6 @@ import time
import sys
def is_similar(jaccard_similarity):
return (jaccard_similarity >= 0.7)
if __name__ == '__main__':
......@@ -29,6 +25,10 @@ if __name__ == '__main__':
input = sys.argv[1]
output = sys.argv[2]
if len(sys.argv) > 3:
jaccard_similarity_threshold = float(sys.argv[3])
else:
jaccard_similarity_threshold = 0.7
url_to_index = {}
index_to_urls = []
......@@ -43,7 +43,7 @@ if __name__ == '__main__':
urls.append(main_url)
for value in myjson[main_url]:
for other_url, js in value.items():
if is_similar(js):
if js >= jaccard_similarity_threshold:
urls.append(other_url)
current_index = -1
other_indices = set()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment