Commit 90c9e3f2 authored by Mostofa Patwary

Other tasks dedup added

parent ebc4a408
@@ -19,6 +19,7 @@ All split documents with less than 200 characters got filtered. Any document
with more than 10 splits got filtered as well.
"""
import argparse
from functools import partial
import json
import multiprocessing
@@ -36,13 +37,59 @@ def get_words(text):
        positions.append(match.start())
    return words, positions
# splits the text
def split_text(text, start_position, remove_char_each_side, seq):

    # first part of the text
    punctuations = ".!?"
    pos = start_position - remove_char_each_side
    text_first = ""
    while pos > 0 and not text[pos] in punctuations:
        pos -= 1
    if pos > 0:
        text_first = text[0:pos+1]

    # add length of seq and remove_char_each_side
    pos = start_position + len(seq) + remove_char_each_side

    # last part of the text
    text_second = ""
    while pos < len(text) and not text[pos] in punctuations:
        pos += 1
    if pos + 1 < len(text):
        text_second = text[pos+1:len(text)]

    return text_first, text_second
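
# Illustrative sketch (hypothetical inputs) of how split_text behaves; a small
# remove_char_each_side of 2 is used only to keep the example short:
# >>> text = ("First sentence here. Some matched ngram text lives here."
# ...         " Last sentence here.")
# >>> seq = "matched ngram text"
# >>> split_text(text, text.index(seq), 2, seq)
# ('First sentence here.', ' Last sentence here.')
# The matched sequence plus remove_char_each_side characters on each side is
# cut out, and each cut is then widened to the nearest '.', '!' or '?'.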

def check_and_clean_text(args, words, ngrams, text, start_position, \
    text_buf_ngram_free, text_buf):

    seq = " ".join(words)
    if seq in ngrams:
        print(" [matched]: {}".format(seq), flush=True)

        # split the text
        text_first, text_second = split_text(text, start_position, \
            args.remove_char_each_side, seq)

        # first part of ngrams free
        if len(text_first) > args.filter_text_char_len:
            text_buf_ngram_free.append(text_first)

        # add second part for further processing
        if len(text_second) > args.filter_text_char_len:
            text_buf.append(text_second)

        return False # not ngram free

    # ngram free
    return True

def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
    # remove all the ngrams
    try:
        myjson = json.loads(line)
        text_buf = [myjson[key]]
    except Exception as e:
        print("Error: {}".format(e), flush=True)
        text_buf = []
@@ -53,102 +100,210 @@ def free_ngram(line, ngrams, ngram_size, filter_text_len,
        # get the first one from the buffer
        text = text_buf.pop(0)
        words, positions = get_words(text)

        ngram_free = True
        # find each max n-gram and check the dictionary
        for i in range(len(words) - args.ngram_size + 1):
            check_ngram_free = check_and_clean_text(args, words[i:\
                i+args.ngram_size], ngrams, text, positions[i], \
                text_buf_ngram_free, text_buf)

            # if the seq is not ngram free, stop scanning this text
            if not check_ngram_free:
                ngram_free = False
                break

            # if the max ngram doesn't match, check whether any lower n-gram
            # within the max ngram matches
            for ngram_len, _ in ngrams_freq_sorted:
                check_ngram_free = check_and_clean_text(args, words[i:\
                    i+ngram_len], ngrams, text, positions[i], \
                    text_buf_ngram_free, text_buf)

                # same check as above
                if not check_ngram_free:
                    ngram_free = False
                    break

            # check for a break from the lower-than-max ngram loop above
            if not ngram_free:
                break

        # for the last max n-gram, check all the lower ngrams in it
        if ngram_free and len(words) - args.ngram_size > 0:
            # get the words of the last max ngram
            last_seq_words = words[(len(words) - args.ngram_size):len(words)]
            last_seq_start_position = len(words) - args.ngram_size

            # check all n-grams lower than the max
            for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):

                # ignore the max ngram as it has been considered already
                if ngram_len == args.ngram_size:
                    continue

                # find each ngram of ngram_len in the last max n-gram and check
                for i in range(len(last_seq_words) - ngram_len + 1):
                    check_ngram_free = check_and_clean_text(args, \
                        last_seq_words[i:i+ngram_len], ngrams, text,\
                        positions[last_seq_start_position+i], \
                        text_buf_ngram_free, text_buf)

                    if not check_ngram_free:
                        ngram_free = False
                        break

                if not ngram_free:
                    break

        # the text is ngram free
        if ngram_free:
            text_buf_ngram_free.append(text)

    # check if the text has only been trimmed
    trimmed = 0
    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) == \
        len(myjson[key]):
        trimmed = 1

    return text_buf_ngram_free, trimmed
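
# Illustrative sketch (hypothetical, standalone use; the script itself maps
# free_ngram over a whole file with multiprocessing in __main__ below):
# >>> import argparse, json
# >>> ns = argparse.Namespace(ngram_size=13, filter_text_char_len=200,
# ...                         remove_char_each_side=200)
# >>> ngrams = {}                      # normally filled by the task builders
# >>> ngrams_freq_sorted = [(13, 0)]   # hypothetical (ngram_length, count) pairs
# >>> line = json.dumps({'text': 'some document ' * 50})
# >>> pieces, trimmed = free_ngram(line, ns, 'text', ngrams, ngrams_freq_sorted)
# With an empty ngrams dictionary nothing matches, so pieces holds the original
# text unchanged and trimmed == 1.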

# insert word sequence into dictionary
def insert_dict(words, ngrams, pos):
    seq = " ".join(words)
    if seq not in ngrams:
        ngrams[seq] = pos

# insert each ngram from text into the ngrams dictionary
def compute_ngrams_insert_dict(args, text, ngrams):
    words, positions = get_words(text)
    if len(words) == 0:
        return

    if len(words) < args.ngram_size:
        insert_dict(words, ngrams, positions[0])

    for i in range(len(words) - args.ngram_size+1):
        insert_dict(words[i:i+args.ngram_size], ngrams, positions[i])

# Build ngrams for the lambada dataset
def process_task_lambda(args, task_file, ngrams):
    print(' reading from {} and computing ngrams'.format(task_file))
    with open(task_file, 'r') as f:
        for line in f:
            try:
                myjson = json.loads(line)
                text = myjson['text']
                compute_ngrams_insert_dict(args, text, ngrams)
            except Exception as e:
                print('Error:', e)
    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)

# Build ngrams for the squad v2 dataset
def process_task_squad(args, ngrams):
    print(' reading squad_v2 from the datasets library and computing ngrams')
    # using squad data from datasets
    from datasets import load_dataset
    squad_v2 = load_dataset('squad_v2', split='validation')
    for line in squad_v2:
        try:
            text = line['question']
            compute_ngrams_insert_dict(args, text, ngrams)
        except Exception as e:
            print('Error:', e)
    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)

if __name__ == '__main__':

    # we use 13-grams, any text less than 200 characters got removed
    # any text split into more than 10 pieces got removed as well
    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs='*', required=True, default=None,
                        help='Tasks to use for deduplication: currently'
                        ' supports [lambada, squad]')
    parser.add_argument('--lambada-path', type=str, default=None,
                        help='Only the lambada task needs this path')
    parser.add_argument('--dedup-dataset', nargs='*', default=None,
                        help='Dataset to deduplicate, with the JSON key to use,'
                        ' e.g. cc.json text')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name for the deduped dataset')
    # Default dedup values
    parser.add_argument('--ngram-size', type=int, default=13,
                        help='Maximum size of ngram to use.')
    parser.add_argument('--filter-text-char-len', type=int, default=200,
                        help='Remove any text below this length.')
    parser.add_argument('--splits-count', type=int, default=10,
                        help='Remove any document split into more than this'
                        ' many pieces.')
    parser.add_argument('--remove-char-each-side', type=int, default=200,
                        help='Number of characters to remove on each side of'
                        ' a matched ngram.')

    args = parser.parse_args()
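
    # Example invocation (illustrative; the script and data file names below
    # are hypothetical, the flags are the ones defined above):
    #   python filter_ngrams.py \
    #       --tasks lambada squad \
    #       --lambada-path lambada_test.jsonl \
    #       --dedup-dataset cc.json text \
    #       --output cc_dedup.json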

    # Build ngrams
    ngrams = {}
    for _, task_name in enumerate(args.tasks):
        print('Task: {}'.format(task_name), flush=True)
        if task_name == 'lambada':
            assert args.lambada_path is not None
            process_task_lambda(args, args.lambada_path, ngrams)
        if task_name == 'squad':
            process_task_squad(args, ngrams)

    # get the range of the size of the ngrams
    ngrams_freq = {}
    for ngram_key in ngrams.keys():
        length = len(ngram_key.split())
        ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
            ngrams_freq else 1

    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[1])
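    # Illustrative note: ngrams_freq_sorted is a list of (ngram_length, count)
    # pairs sorted by count, e.g. something like [(7, 3), (13, 5200)] for a
    # hypothetical run; free_ngram() walks this list so that the shorter ngrams
    # contributed by short task texts (e.g. squad questions) are also checked.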
print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
ngrams_freq_sorted) -1 ][0]), flush=True)
id_prefix = '-'.join(args.tasks[::2])
print('Reading file {} and deduping n-grams'.format(args.dedup_dataset))
print('Reading file {} and deduping n-grams'.format(dedup_file))
counter = 0 counter = 0
start_time = time.time() start_time = time.time()
out_f = open(output_file, 'wb') out_f = open(args.output, 'wb')
splitted, ignored, split_mt_thld = 0, 0, 0 splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0
assert len(args.dedup_dataset) == 2
dedup_file = args.dedup_dataset[0]
dedup_key = args.dedup_dataset[1]

    # Setup multi-processing.
    num_workers = 1 # 40
    fin = open(dedup_file, 'r', encoding='utf-8')
    pool = multiprocessing.Pool(num_workers)
    free_ngram_x = partial(free_ngram, args=args, key=dedup_key, ngrams=ngrams, \
        ngrams_freq_sorted=ngrams_freq_sorted)
    free_ngrams = pool.imap(free_ngram_x, fin, 25)

    for text_buf_ngram_free, trimmed in free_ngrams:
        counter += 1
        try:
            trimmed_count += trimmed

            if len(text_buf_ngram_free) > 1:
                splitted += (len(text_buf_ngram_free) - 1)
            if len(text_buf_ngram_free) == 0:
                ignored += 1
            # more than 10 splits ignored
            if len(text_buf_ngram_free) > args.splits_count:
                text_buf_ngram_free = []
                split_mt_thld += 1
@@ -167,7 +322,11 @@ if __name__ == '__main__':
        except Exception as e:
            print('Error:', e)

    out_f.close()
    fin.close()

    print("Deduped file written to: {}".format(args.output), flush=True)
    print("Total docs {} splitted {} ignored {} docs with many splits {}"\
        " trimmed {}".format(counter, splitted, ignored, split_mt_thld, \
        trimmed_count), flush=True)
    print('done :-)')