Commit c44f7622 authored by Mostofa Patwary

Many more features added

parent 6013e23c
@@ -24,6 +24,7 @@ from functools import partial
import json
import multiprocessing
import nltk
import pickle
import re
import string
import sys
@@ -61,11 +62,23 @@ def split_text(text, start_position, remove_char_each_side, seq):
return text_first, text_second
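The body of split_text is elided by this hunk; going only by its signature and its call sites, a minimal sketch of the behaviour it appears to implement (drop the matched sequence plus a margin of remove_char_each_side characters on each side, and return the text before and after that window) might look like the following. The margin handling here is an assumption, not taken from the source.
def split_text_sketch(text, start_position, remove_char_each_side, seq):
    # keep what comes before the matched sequence, minus a safety margin
    text_first = text[:max(0, start_position - remove_char_each_side)]
    # keep what comes after the matched sequence, again minus a margin
    end_position = start_position + len(seq) + remove_char_each_side
    text_second = text[end_position:] if end_position < len(text) else ""
    return text_first, text_second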
def check_and_clean_text(args, words, ngrams, text, start_position, \
text_buf_ngram_free, text_buf):
text_buf_ngram_free, text_buf, local_ngram):
seq = " ".join(words)
if seq in ngrams:
print(" [matched]: {}".format(seq), flush=True)
#print(" [matched]: {}".format(seq), flush=True)
if args.get_ngram_freq_only:
# increase freq of this seq and then only consider the later part
# of the text for further processing
if seq in local_ngram:
local_ngram[seq] += 1
else:
local_ngram[seq] = 1
#print(" [increased]: {} {}".format(seq, ngrams[seq]), flush=True)
if (start_position + len(seq) + 1) < len(text):
text_buf.append(text[start_position + len(seq) + 1:len(text)])
return False
# split the text
text_first, text_second = split_text(text, start_position, \
@@ -84,6 +97,7 @@ def check_and_clean_text(args, words, ngrams, text, start_position, \
# ngram free
return True
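As a toy illustration of the split-and-requeue idea that check_and_clean_text and free_ngram implement together: whenever a banned n-gram is found, the document is cut around the match and the remaining pieces go back onto the work buffer until every piece is n-gram free. This sketch leaves out the character margins and the frequency-only branch; it is not the script's code.
banned = {"the quick brown fox"}
text_buf = ["once upon a time the quick brown fox jumped over the lazy dog"]
text_buf_ngram_free = []
while text_buf:
    text = text_buf.pop(0)
    hit = next((seq for seq in banned if seq in text), None)
    if hit is None:
        # no banned n-gram left in this piece
        text_buf_ngram_free.append(text)
        continue
    # cut the document around the match and requeue both halves
    start = text.find(hit)
    left, right = text[:start], text[start + len(hit):]
    text_buf.extend(piece.strip() for piece in (left, right) if piece.strip())
print(text_buf_ngram_free)  # ['once upon a time', 'jumped over the lazy dog']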
def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
# remove all the ngrams
@@ -95,18 +109,19 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
text_buf = []
text_buf_ngram_free = []
local_ngram = {}
while len(text_buf) > 0:
# get the first one from the buffer
text = text_buf.pop(0)
words, positions = get_words(text)
ngram_free = True
# find each max n-gram and check the dictionary
for i in range(len(words) - args.ngram_size + 1):
for i in range(len(words) - args.max_ngram_size + 1):
check_ngram_free = check_and_clean_text(args, words[i:\
i+args.ngram_size], ngrams, text, positions[i], \
text_buf_ngram_free, text_buf)
i+args.max_ngram_size], ngrams, text, positions[i], \
text_buf_ngram_free, text_buf, local_ngram)
# if the seq is not ngram free, break
if not check_ngram_free:
@@ -118,7 +133,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
for ngram_len, _ in ngrams_freq_sorted:
check_ngram_free = check_and_clean_text(args, words[i:\
i+ngram_len], ngrams, text, positions[i], \
text_buf_ngram_free, text_buf)
text_buf_ngram_free, text_buf, local_ngram)
# same check as above
if not check_ngram_free:
@@ -130,16 +145,16 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
break
# for the last max n-gram, check all the lower ngrams in it
if ngram_free and len(words) - args.ngram_size > 0:
if ngram_free and len(words) - args.max_ngram_size > 0:
# get the words of the last max ngram
last_seq_words = words[(len(words) - args.ngram_size):len(words)]
last_seq_start_position = len(words) - args.ngram_size
last_seq_words = words[(len(words)-args.max_ngram_size):len(words)]
last_seq_start_position = len(words) - args.max_ngram_size
# check all n-grams lower than the max
for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):
# ignore the max ngram as it has already been considered
if ngram_len == args.ngram_size:
if ngram_len == args.max_ngram_size:
continue
# find each ngram of ngram_len in max n-grams and check
@@ -147,7 +162,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
check_ngram_free = check_and_clean_text(args, \
last_seq_words[i:i+ngram_len], ngrams, text,\
positions[last_seq_start_position+i], \
text_buf_ngram_free, text_buf)
text_buf_ngram_free, text_buf, local_ngram)
if not check_ngram_free:
ngram_free = False
@@ -157,34 +172,35 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
break
# texts are ngram free
if ngram_free:
if ngram_free and not args.get_ngram_freq_only:
text_buf_ngram_free.append(text)
# check if the text has only been trimmed
trimmed = 0
if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \
len(myjson[key]):
if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
len(text_buf_ngram_free[0]) < len(myjson[key]):
trimmed = 1
return text_buf_ngram_free, trimmed
return text_buf_ngram_free, trimmed, local_ngram
# insert word sequence into dictionary
def insert_dict(words, ngrams, pos):
seq = " ".join(words)
if seq not in ngrams:
ngrams[seq] = pos
ngrams[seq] = 0
#ngrams[seq] = pos
# insert each ngram from text into the ngrams dictionary
def compute_ngrams_insert_dict(args, text, ngrams):
words, positions = get_words(text)
if len(words) == 0:
if len(words) < args.min_ngram_size:
return
if len(words) < args.ngram_size:
if len(words) < args.max_ngram_size:
insert_dict(words, ngrams, positions[0])
for i in range(len(words) - args.ngram_size+1):
insert_dict(words[i:i+args.ngram_size], ngrams, positions[i])
for i in range(len(words) - args.max_ngram_size+1):
insert_dict(words[i:i+args.max_ngram_size], ngrams, positions[i])
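A compact sketch of the sliding-window insertion above, assuming plain whitespace tokenisation (the script's get_words also returns character positions, which are omitted here). Every value starts at 0 because, in this version, counts are filled in later by the frequency pass over the large dataset.
def collect_max_ngrams(text, min_ngram_size=8, max_ngram_size=13):
    ngrams = {}
    words = text.split()
    if len(words) < min_ngram_size:
        return ngrams                     # too short to contribute any ngram
    if len(words) < max_ngram_size:
        ngrams[" ".join(words)] = 0       # keep the whole short text as one key
    for i in range(len(words) - max_ngram_size + 1):
        ngrams[" ".join(words[i:i + max_ngram_size])] = 0
    return ngrams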
# Build ngrams for the lambada dataset
@@ -203,6 +219,7 @@ def process_task_lambda(args, task_file, ngrams):
# Build ngrams for the dataset of the given task
def process_task(args, task_name, ngrams):
print(' reading from {} and computing ngrams'.format('import datasets'))
print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
# using validation/test data from datasets
@@ -253,39 +270,7 @@ def process_task(args, task_name, ngrams):
print(" After task {} entities in ngrams {}, added {}".format(task_name, \
len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)
if __name__ == '__main__':
# we use 13-grams; any text shorter than 200 characters is removed
# any text split into more than 10 pieces is removed as well
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
help = 'Tasks to use for deduplication: currently '
' support [lambada, squad, natural_questions,'
' triviaqa, webqa, race, drop, coqa, and piqa]')
parser.add_argument('--lambada-path', type=str, default=None,
help='Only Lambada task needs the path')
parser.add_argument('--dedup-dataset', nargs = '*', default=None,
help='Dataset to deduplicate with the key to use'
' e.g. cc.json text')
parser.add_argument('--output', type=str, default=None,
help='Output file name to save dedup dataset')
# Default dedup values
parser.add_argument('--ngram-size', type=int, default=13,
help='Maximum size of ngram to use.')
parser.add_argument('--filter-text-char-len', type=int, default=200,
help='Remove any text below this length.')
parser.add_argument('--splits-count', type=int, default=10,
help='Remove any documents more than this many splits')
parser.add_argument('--remove-char-each-side', type=int, default=200,
help='Number of characters to remove on each side of a matched ngram.')
args = parser.parse_args()
# Build ngrams
ngrams = {}
def compute_tasks_ngrams(args, ngrams):
start_time = time.time()
for _, task_name in enumerate(args.tasks):
print('Task: {}'.format(task_name), flush=True)
@@ -294,10 +279,10 @@ if __name__ == '__main__':
process_task_lambda(args, args.lambada_path, ngrams)
else:
process_task(args, task_name, ngrams)
print(" Taken time to compute ngrams {:.2f}".format(time.time() - \
start_time), flush=True)
print(" Taken time {:.2f}".format(time.time() - start_time), flush=True)
# get the range of the size of the ngrams
def compute_ngram_freq_sorted(args, ngrams):
ngrams_freq = {}
for ngram_key in ngrams.keys():
length = len(ngram_key.split())
@@ -309,33 +294,74 @@ if __name__ == '__main__':
print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
ngrams_freq_sorted) -1 ][0]), flush=True)
return ngrams_freq_sorted
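The middle of compute_ngram_freq_sorted is not shown; a minimal sketch consistent with the surrounding lines (count how many dictionary entries exist per n-gram length in words, then sort by length so the first and last entries give the min and max ngram sizes printed above) could be:
def ngram_length_histogram(ngrams):
    ngrams_freq = {}
    for ngram_key in ngrams:
        length = len(ngram_key.split())
        ngrams_freq[length] = ngrams_freq.get(length, 0) + 1
    # sorted list of (ngram length, number of entries of that length)
    return sorted(ngrams_freq.items())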
id_prefix = '-'.join(args.tasks[::2])
print('Reading file {} and deduping n-grams'.format(args.dedup_dataset))
def get_ngrams_above_threshold(args, ngrams, ngrams_above_threshold, \
dedup_file, dedup_key, ngrams_freq_sorted):
start_time = time.time()
# get the ngrams frequency
args.get_ngram_freq_only = True
# Open the large file to process in parallel
num_workers = 40
pool = multiprocessing.Pool(num_workers)
fin = open(dedup_file, 'r', encoding='utf-8')
free_ngram_abt_partial=partial(free_ngram, args=args, key=dedup_key, \
ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)
counter = 0
for _, _, local_ngram in free_ngrams_abt:
counter += 1
if counter % 1000 == 0:
print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
for local_key in local_ngram:
if local_key in ngrams:
ngrams[local_key] += 1
local_ngram = {}
print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
start_time), flush=True)
pool.close()
pool.join()
start_time = time.time()
counter_threshold = 0
# Get ngrams above threshold
for local_key, local_val in ngrams.items():
if ngrams[local_key] > args.key_threshold:
print(" [threshold] {} {}".format(local_key, local_val), flush=True)
counter_threshold += 1
ngrams_above_threshold[local_key] = 1
print(' Ngrams above threshold {}'.format(counter_threshold), flush=True)
fin.close()
if args.output is not None:
out_f = open(args.output, 'wb')
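Both passes over the large file use the same multiprocessing pattern: bind the read-only arguments with functools.partial and stream the file's lines through Pool.imap with a chunksize. Below is a standalone sketch of that pattern only; the worker is hypothetical, and 'cc.json' with key 'text' is just the example dataset named in the --dedup-dataset help.
import json
import multiprocessing
from functools import partial

def count_hits(line, ngrams, key):
    # hypothetical worker: count banned n-gram occurrences in one JSON document
    text = json.loads(line)[key]
    return sum(1 for seq in ngrams if seq in text)

if __name__ == '__main__':
    ngrams = {"some banned thirteen gram goes here": 0}
    worker = partial(count_hits, ngrams=ngrams, key='text')
    with open('cc.json', 'r', encoding='utf-8') as fin, \
            multiprocessing.Pool(4) as pool:
        for hits in pool.imap(worker, fin, 500):
            pass  # aggregate the per-document results here, as the script does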
def clean_ngrams_above_threshold(args, ngrams_above_threshold, dedup_file, \
dedup_key):
splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0
start_time = time.time()
# Now actually filter the dataset
args.get_ngram_freq_only = False
id_prefix = '-'.join(args.tasks[::2])
assert len(args.dedup_dataset) == 2
dedup_file = args.dedup_dataset[0]
dedup_key = args.dedup_dataset[1]
# get the range of the size of the ngrams
ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_above_threshold)
# Setup multi-processing.
# Open the large file to process in parallel
num_workers = 40
fin = open(dedup_file, 'r', encoding='utf-8')
pool = multiprocessing.Pool(num_workers)
free_ngram_x=partial(free_ngram, args=args, key=dedup_key, ngrams=ngrams, \
ngrams_freq_sorted=ngrams_freq_sorted)
free_ngrams = pool.imap(free_ngram_x, fin, 25)
for text_buf_ngram_free, trimmed in free_ngrams:
fin = open(dedup_file, 'r', encoding='utf-8')
free_ngram_clean_partial=partial(free_ngram, args=args, key=dedup_key, \
ngrams=ngrams_above_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)
out_f = open(args.output, 'wb')
counter = splitted = ignored = split_mt_thld = trimmed_count = 0
for text_buf_ngram_free, trimmed, _ in free_ngrams_clean:
counter += 1
try:
@@ -361,18 +387,95 @@ if __name__ == '__main__':
out_f.write('\n'.encode('utf-8'))
if counter % 1000 == 0:
print(' [search]> processed {} documents in {:.2f} seconds ...'.
print(' [final]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
except Exception as e:
print('Error:', e)
if args.output is not None:
out_f.close()
print(' [final]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
print(' Total docs {} split {} ignored {} splits > threshold {} trimmed'\
' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count)\
, flush=True)
pool.close()
pool.join()
out_f.close()
fin.close()
print("Deduped file written to: {}".format(args.output), flush=True)
print("Total docs {} splitted {} ignored {} docs with many splits {}"\
" trimmed {}".format(counter, splitted, ignored, split_mt_thld, \
trimmed_count), flush=True)
if __name__ == '__main__':
# we use 13-grams; any text shorter than 200 characters is removed
# any text split into more than 10 pieces is removed as well
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
help = 'Tasks to use for deduplication: currently '
' support [lambada, squad, natural_questions,'
' triviaqa, webqa, race, drop, coqa, and piqa]')
parser.add_argument('--lambada-path', type=str, default=None,
help='Only Lambada task needs the path')
parser.add_argument('--dedup-dataset', nargs = '*', default=None,
help='Dataset to deduplicate with the key to use'
' e.g. cc.json text')
parser.add_argument('--output', type=str, default=None,
help='Output file name to save dedup dataset')
# Default dedup values
parser.add_argument('--max-ngram-size', type=int, default=13,
help='Maximum size of ngram to use.')
parser.add_argument('--min-ngram-size', type=int, default=8,
help='Minimum size of ngram to use.')
parser.add_argument('--filter-text-char-len', type=int, default=200,
help='Remove any text below this length.')
parser.add_argument('--key-threshold', type=int, default=10,
help='Frequency threshold; ngrams occurring more often than this are used for deduplication')
parser.add_argument('--save-dictionary', type=str, default=None,
help='Save the dictionary')
parser.add_argument('--load-dictionary', type=str, default=None,
help='Load the dictionary')
parser.add_argument('--splits-count', type=int, default=10,
help='Remove any documents more than this many splits')
parser.add_argument('--remove-char-each-side', type=int, default=200,
help='Number of characters to remove on each side of a matched ngram.')
args = parser.parse_args()
assert len(args.dedup_dataset) == 2
dedup_file = args.dedup_dataset[0]
dedup_key = args.dedup_dataset[1]
# Setup multi-processing
num_workers = 40
if args.load_dictionary is None:
# Build ngrams
ngrams = {}
compute_tasks_ngrams(args, ngrams)
# get the range of the size of the ngrams
ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)
# get ngram freq from large file in parallel
# get ngrams above threshold
ngrams_above_threshold = {}
get_ngrams_above_threshold(args, ngrams, ngrams_above_threshold, \
dedup_file, dedup_key, ngrams_freq_sorted)
# save the dictionary if needed
if args.save_dictionary is not None:
with open(args.save_dictionary, 'wb') as save_dict_handle:
pickle.dump(ngrams_above_threshold, save_dict_handle)
else:
with open(args.load_dictionary, 'rb') as load_dict_handle:
ngrams_above_threshold = pickle.load(load_dict_handle)
# filter the large file
if args.output is not None:
clean_ngrams_above_threshold(args, ngrams_above_threshold, \
dedup_file, dedup_key)
print('done :-)')