Commit 90c9e3f2 authored by Mostofa Patwary

Other tasks dedup added

parent ebc4a408
@@ -19,6 +19,7 @@ All split documents with less than 200 characters got filtered. Any document
with more than 10 splits got filtered as well.
"""
+import argparse
from functools import partial
import json
import multiprocessing
@@ -36,40 +37,20 @@ def get_words(text):
        positions.append(match.start())
    return words, positions
-def free_ngram(line, ngrams, ngram_size, filter_text_len,
-               splits_count, split_window_each_size):
-    # remove all the ngrams
-    try:
-        myjson = json.loads(line)
-        text_buf = [myjson['text']]
-    except Exception as e:
-        print("Error: {}".format(e), flush=True)
-        text_buf = []
-    text_buf_ngram_free = []
-    while len(text_buf) > 0:
-        # get the first one from the buffer
-        text = text_buf.pop(0)
-        words, positions = get_words(text)
-        not_ngram_free = True
-        punctuations = ".!?"
-        # find n-grams
-        for i in range(len(words) - ngram_size + 1):
-            seq = " ".join(words[i:i+ngram_size])
-            if seq in ngrams:
-                # splits the text
+# splits the text
+def split_text(text, start_position, remove_char_each_side, seq):
    # first part of the text
-                pos = positions[i] - split_window_each_size
+    punctuations = ".!?"
+    pos = start_position - remove_char_each_side
    text_first = ""
    while pos > 0 and not text[pos] in punctuations:
        pos -= 1
    if pos > 0:
        text_first = text[0:pos+1]
-                pos = positions[i] + split_window_each_size
+    # add length of seq and remove_char_each_side
+    pos = start_position + len(seq) + remove_char_each_side
    # last part of the text
    text_second = ""
    while pos < len(text) and not text[pos] in punctuations:
@@ -77,78 +58,252 @@ def free_ngram(line, ngrams, ngram_size, filter_text_len,
    if pos + 1 < len(text):
        text_second = text[pos+1:len(text)]
+    return text_first, text_second
+def check_and_clean_text(args, words, ngrams, text, start_position, \
+    text_buf_ngram_free, text_buf):
+    seq = " ".join(words)
+    if seq in ngrams:
+        print(" [matched]: {}".format(seq), flush=True)
+        # split the text
+        text_first, text_second = split_text(text, start_position, \
+            args.remove_char_each_side, seq)
        # first part of ngrams free
-                if len(text_first) > filter_text_len:
+        if len(text_first) > args.filter_text_char_len:
            text_buf_ngram_free.append(text_first)
        # add second part for further processing
-                if len(text_second) > filter_text_len:
+        if len(text_second) > args.filter_text_char_len:
            text_buf.append(text_second)
-                not_ngram_free = False
+        return False # not ngram free
+    # ngram free
+    return True
+def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
+    # remove all the ngrams
+    try:
+        myjson = json.loads(line)
+        text_buf = [myjson[key]]
+    except Exception as e:
+        print("Error: {}".format(e), flush=True)
+        text_buf = []
+    text_buf_ngram_free = []
+    while len(text_buf) > 0:
+        # get the first one from the buffer
+        text = text_buf.pop(0)
+        words, positions = get_words(text)
+        ngram_free = True
+        # find each max n-gram and check the dictionary
+        for i in range(len(words) - args.ngram_size + 1):
+            check_ngram_free = check_and_clean_text(args, words[i:\
+                i+args.ngram_size], ngrams, text, positions[i], \
+                text_buf_ngram_free, text_buf)
+            # if the seq is not ngram free, break
+            if not check_ngram_free:
+                ngram_free = False
+                break
+            # if the max ngram doesn't match, check if any lower n-gram
+            # within the max ngram matches
+            for ngram_len, _ in ngrams_freq_sorted:
+                check_ngram_free = check_and_clean_text(args, words[i:\
+                    i+ngram_len], ngrams, text, positions[i], \
+                    text_buf_ngram_free, text_buf)
+                # same check as above
+                if not check_ngram_free:
+                    ngram_free = False
+                    break
+            # check for a break from the lower-than-max ngram loop above
+            if not ngram_free:
+                break
+        # for the last max n-gram, check all the lower ngrams in it
+        if ngram_free and len(words) - args.ngram_size > 0:
+            # get the last words of the last max ngram
+            last_seq_words = words[(len(words) - args.ngram_size):len(words)]
+            last_seq_start_position = len(words) - args.ngram_size
+            # check all n-grams lower than the max
+            for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):
+                # ignore the max ngram as it has been considered already
+                if ngram_len == args.ngram_size:
+                    continue
+                # find each ngram of ngram_len in the max n-gram and check
+                for i in range(len(last_seq_words) - ngram_len + 1):
+                    check_ngram_free = check_and_clean_text(args, \
+                        last_seq_words[i:i+ngram_len], ngrams, text,\
+                        positions[last_seq_start_position+i], \
+                        text_buf_ngram_free, text_buf)
+                    if not check_ngram_free:
+                        ngram_free = False
+                        break
+                if not ngram_free:
-                break
+                    break
-        # text are ngram free
+        # texts are ngram free
-        if not_ngram_free:
+        if ngram_free:
            text_buf_ngram_free.append(text)
-    return text_buf_ngram_free
+    # check if the text has only been trimmed
+    trimmed = 0
+    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) == \
+        len(myjson[key]):
+        trimmed = 1
+    return text_buf_ngram_free, trimmed
-if __name__ == '__main__':
-    print('finding possible duplicate content ...')
-    main_file = sys.argv[1] # lambada file
-    dedup_file = sys.argv[2] # Book corpus
-    output_file = sys.argv[3] # Filtered book corpus
-    ngrams = {}
-    id_prefix = "lambada"
-    # we use 13-grams, any text less than 200 characters got removed
-    # any text splitted more than 10 got removed as well
-    ngram_size = 13
-    filter_text_len = 200
-    splits_count = 10
-    split_window_each_size = 200
-    print('Reading file {} and computing ngrams'.format(main_file))
-    with open(main_file, 'r') as f:
+# insert word sequence into dictionary
+def insert_dict(words, ngrams, pos):
+    seq = " ".join(words)
+    if seq not in ngrams:
+        ngrams[seq] = pos
+# insert each ngram from text into the ngrams dictionary
+def compute_ngrams_insert_dict(args, text, ngrams):
+    words, positions = get_words(text)
+    if len(words) == 0:
+        return
+    if len(words) < args.ngram_size:
+        insert_dict(words, ngrams, positions[0])
+    for i in range(len(words) - args.ngram_size+1):
+        insert_dict(words[i:i+args.ngram_size], ngrams, positions[i])
+# Build ngrams for the lambada dataset
+def process_task_lambda(args, task_file, ngrams):
+    print(' reading from {} and computing ngrams'.format(task_file))
+    with open(task_file, 'r') as f:
        for line in f:
            try:
                myjson = json.loads(line)
-                words, positions = get_words(myjson['text'])
-                for i in range(len(words) - ngram_size+1):
-                    seq = " ".join(words[i:i+ngram_size])
-                    if seq not in ngrams:
-                        ngrams[seq] = positions[i]
+                text = myjson['text']
+                compute_ngrams_insert_dict(args, text, ngrams)
+            except Exception as e:
+                print('Error:', e)
+    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
+# Build ngrams for the squad v2 dataset
+def process_task_squad(args, ngrams):
+    print(' reading from {} and computing ngrams'.format('import datasets'))
+    # using squad data from datasets
+    from datasets import load_dataset
+    squad_v2 = load_dataset('squad_v2', split='validation')
+    for line in squad_v2:
+        try:
+            text = line['question']
+            compute_ngrams_insert_dict(args, text, ngrams)
        except Exception as e:
            print('Error:', e)
-    print("ngrams size {}".format(len(ngrams)))
+    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
+if __name__ == '__main__':
+    # we use 13-grams, any text less than 200 characters got removed
+    # any text splitted more than 10 got removed as well
+    print('parsing the arguments ...')
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--tasks', nargs='*', required=True, default=None,
+        help='Tasks to use for deduplication: currently '
+        'supports [lambada, squad]')
+    parser.add_argument('--lambada-path', type=str, default=None,
+        help='Only the lambada task needs this path')
+    parser.add_argument('--dedup-dataset', nargs='*', default=None,
+        help='Dataset to deduplicate, with the key to use, '
+        'e.g. cc.json text')
+    parser.add_argument('--output', type=str, default=None,
+        help='Output file name to save the deduped dataset')
+    # Default dedup values
+    parser.add_argument('--ngram-size', type=int, default=13,
+        help='Maximum size of ngram to use.')
+    parser.add_argument('--filter-text-char-len', type=int, default=200,
+        help='Remove any text below this length.')
+    parser.add_argument('--splits-count', type=int, default=10,
+        help='Remove any document split into more than this many pieces.')
+    parser.add_argument('--remove-char-each-side', type=int, default=200,
+        help='Number of characters to remove on each side of a matched ngram.')
+    args = parser.parse_args()
+    # Build ngrams
+    ngrams = {}
+    for _, task_name in enumerate(args.tasks):
+        print('Task: {}'.format(task_name), flush=True)
+        if task_name == 'lambada':
+            assert args.lambada_path is not None
+            process_task_lambda(args, args.lambada_path, ngrams)
+        if task_name == 'squad':
+            process_task_squad(args, ngrams)
+    # get the range of the size of the ngrams
+    ngrams_freq = {}
+    for ngram_key in ngrams.keys():
+        length = len(ngram_key.split())
+        ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
+            ngrams_freq else 1
+    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[1])
+    print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
+    print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
+        len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
+        ngrams_freq_sorted) - 1][0]), flush=True)
+    id_prefix = '-'.join(args.tasks[::2])
+    print('Reading file {} and deduping n-grams'.format(args.dedup_dataset))
-    print('Reading file {} and deduping n-grams'.format(dedup_file))
    counter = 0
    start_time = time.time()
-    out_f = open(output_file, 'wb')
-    splitted, ignored, split_mt_thld = 0, 0, 0
+    out_f = open(args.output, 'wb')
+    splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0
+    assert len(args.dedup_dataset) == 2
+    dedup_file = args.dedup_dataset[0]
+    dedup_key = args.dedup_dataset[1]
    # Setup multi-processing.
-    num_workers = 40
+    num_workers = 1 #40
    fin = open(dedup_file, 'r', encoding='utf-8')
    pool = multiprocessing.Pool(num_workers)
-    free_ngram_x=partial(free_ngram, ngrams=ngrams, ngram_size=ngram_size,
-        filter_text_len=filter_text_len, splits_count=splits_count,
-        split_window_each_size=split_window_each_size)
+    free_ngram_x=partial(free_ngram, args=args, key=dedup_key, ngrams=ngrams, \
+        ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams = pool.imap(free_ngram_x, fin, 25)
-    for text_buf_ngram_free in free_ngrams:
+    for text_buf_ngram_free, trimmed in free_ngrams:
        counter += 1
        try:
+            trimmed_count += trimmed
            if len(text_buf_ngram_free) > 1:
                splitted += (len(text_buf_ngram_free) - 1)
            if len(text_buf_ngram_free) == 0:
                ignored += 1
            # more than 10 splits ignored
-            if len(text_buf_ngram_free) > splits_count:
+            if len(text_buf_ngram_free) > args.splits_count:
                text_buf_ngram_free = []
                split_mt_thld += 1
@@ -167,7 +322,11 @@ if __name__ == '__main__':
        except Exception as e:
            print('Error:', e)
-    print("Deduped file written to: {}".format(output_file), flush=True)
-    print("Total docs {} splitted {} ignored {} docs with many splits {}".\
-        format(counter, splitted, ignored, split_mt_thld), flush=True)
+    out_f.close()
+    fin.close()
+    print("Deduped file written to: {}".format(args.output), flush=True)
+    print("Total docs {} splitted {} ignored {} docs with many splits {}"\
+        " trimmed {}".format(counter, splitted, ignored, split_mt_thld, \
+        trimmed_count), flush=True)
    print('done :-)')
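To make the new flow easier to follow, below is a minimal, self-contained Python sketch of the core matching step only (a task n-gram dictionary plus a sliding-window membership check). It is illustrative and not part of the commit: the helper names build_ngrams and has_task_ngram are hypothetical, and the real script additionally splits matched documents at sentence punctuation, filters short or over-split texts, and runs under multiprocessing.

# Minimal sketch of the dedup idea in this commit (illustrative only).
import re

def get_words(text):
    # alphanumeric words and their character positions, as in the script above
    words, positions = [], []
    for match in re.finditer(r'\w+', text):
        words.append(match.group(0))
        positions.append(match.start())
    return words, positions

def build_ngrams(task_texts, ngram_size=13):
    # collect every ngram_size-word sequence that appears in the task data
    ngrams = set()
    for text in task_texts:
        words, _ = get_words(text)
        for i in range(len(words) - ngram_size + 1):
            ngrams.add(" ".join(words[i:i + ngram_size]))
    return ngrams

def has_task_ngram(doc_text, ngrams, ngram_size=13):
    # slide a window over the document words and check the dictionary
    words, _ = get_words(doc_text)
    return any(" ".join(words[i:i + ngram_size]) in ngrams
               for i in range(len(words) - ngram_size + 1))

if __name__ == '__main__':
    task_texts = ["the quick brown fox jumps over the lazy dog near the old river bank today"]
    ngrams = build_ngrams(task_texts)
    doc = ("Intro sentence. the quick brown fox jumps over the lazy dog "
           "near the old river bank today. Outro sentence.")
    # True -> the real script would split this document around the match
    print(has_task_ngram(doc, ngrams))

An assumed invocation of the full script (the filename filter_ngrams.py and the paths are illustrative; the flags come from the argparse options added in this commit) would look like: python filter_ngrams.py --tasks lambada squad --lambada-path <lambada jsonl> --dedup-dataset cc.json text --output cc_dedup.json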