"git@developer.sourcefind.cn:change/sglang.git" did not exist on "84022c0e563c9b519fa16bf0eb7f9d750a728531"
Commit c44f7622 authored by Mostofa Patwary

Many more features added

parent 6013e23c
@@ -24,6 +24,7 @@ from functools import partial
 import json
 import multiprocessing
 import nltk
+import pickle
 import re
 import string
 import sys
@@ -61,11 +62,23 @@ def split_text(text, start_position, remove_char_each_side, seq):
     return text_first, text_second

 def check_and_clean_text(args, words, ngrams, text, start_position, \
-    text_buf_ngram_free, text_buf):
+    text_buf_ngram_free, text_buf, local_ngram):

     seq = " ".join(words)
     if seq in ngrams:
-        print(" [matched]: {}".format(seq), flush=True)
+        #print(" [matched]: {}".format(seq), flush=True)
+
+        if args.get_ngram_freq_only:
+            # increase freq of this seq and then only consider the later part
+            # of the text for further processing
+            if seq in local_ngram:
+                local_ngram[seq] += 1
+            else:
+                local_ngram[seq] = 1
+            #print(" [increased]: {} {}".format(seq, ngrams[seq]), flush=True)
+            if (start_position + len(seq) + 1) < len(text):
+                text_buf.append(text[start_position + len(seq) + 1:len(text)])
+            return False

         # split the text
         text_first, text_second = split_text(text, start_position, \
@@ -84,6 +97,7 @@ def check_and_clean_text(args, words, ngrams, text, start_position, \
     # ngram free
     return True

+
 def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
     # remove all the ngrams
@@ -95,6 +109,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
     text_buf = []
     text_buf_ngram_free = []
+    local_ngram = {}

     while len(text_buf) > 0:
         # get the first one from the buffer
@@ -103,10 +118,10 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
         ngram_free = True
         # find each max n-grams and check dictionary
-        for i in range(len(words) - args.ngram_size + 1):
+        for i in range(len(words) - args.max_ngram_size + 1):
             check_ngram_free = check_and_clean_text(args, words[i:\
-                i+args.ngram_size], ngrams, text, positions[i], \
-                text_buf_ngram_free, text_buf)
+                i+args.max_ngram_size], ngrams, text, positions[i], \
+                text_buf_ngram_free, text_buf, local_ngram)

             # the seq is ngram free? if yes, break
             if not check_ngram_free:
@@ -118,7 +133,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                 for ngram_len, _ in ngrams_freq_sorted:
                     check_ngram_free = check_and_clean_text(args, words[i:\
                         i+ngram_len], ngrams, text, positions[i], \
-                        text_buf_ngram_free, text_buf)
+                        text_buf_ngram_free, text_buf, local_ngram)

                     # same check as above
                     if not check_ngram_free:
@@ -130,16 +145,16 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                         break

         # for the last max n-gram, check all the lower ngrams in it
-        if ngram_free and len(words) - args.ngram_size > 0:
+        if ngram_free and len(words) - args.max_ngram_size > 0:
             # get the last words of the lax max ngram
-            last_seq_words = words[(len(words) - args.ngram_size):len(words)]
-            last_seq_start_position = len(words) - args.ngram_size
+            last_seq_words = words[(len(words)-args.max_ngram_size):len(words)]
+            last_seq_start_position = len(words) - args.max_ngram_size

             # check all n-grams lower than the max
             for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):
                 # ignore the max ngram as has been considered already
-                if ngram_len == args.ngram_size:
+                if ngram_len == args.max_ngram_size:
                     continue

                 # find each ngram of ngram_len in max n-grams and check
@@ -147,7 +162,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                     check_ngram_free = check_and_clean_text(args, \
                         last_seq_words[i:i+ngram_len], ngrams, text,\
                         positions[last_seq_start_position+i], \
-                        text_buf_ngram_free, text_buf)
+                        text_buf_ngram_free, text_buf, local_ngram)

                     if not check_ngram_free:
                         ngram_free = False
@@ -157,34 +172,35 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
                         break

         # texts are ngram free
-        if ngram_free:
+        if ngram_free and not args.get_ngram_freq_only:
             text_buf_ngram_free.append(text)

     # check if the text has only been trimmed
     trimmed = 0
-    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \
-        len(myjson[key]):
+    if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
+        len(text_buf_ngram_free[0]) < len(myjson[key]):
         trimmed = 1

-    return text_buf_ngram_free, trimmed
+    return text_buf_ngram_free, trimmed, local_ngram

 # insert word sequence into dictionary
 def insert_dict(words, ngrams, pos):
     seq = " ".join(words)
     if seq not in ngrams:
-        ngrams[seq] = pos
+        ngrams[seq] = 0
+        #ngrams[seq] = pos

 # insert each ngram from text into the ngrams dictionary
 def compute_ngrams_insert_dict(args, text, ngrams):
     words, positions = get_words(text)
-    if len(words) == 0:
+    if len(words) < args.min_ngram_size:
         return

-    if len(words) < args.ngram_size:
+    if len(words) < args.max_ngram_size:
         insert_dict(words, ngrams, positions[0])

-    for i in range(len(words) - args.ngram_size+1):
-        insert_dict(words[i:i+args.ngram_size], ngrams, positions[i])
+    for i in range(len(words) - args.max_ngram_size+1):
+        insert_dict(words[i:i+args.max_ngram_size], ngrams, positions[i])

 # Build ngrams for the lambada dataset
@@ -203,6 +219,7 @@ def process_task_lambda(args, task_file, ngrams):
 # Build ngrams for the dataset of the given task
 def process_task(args, task_name, ngrams):
     print(' reading from {} and computing ngrams'.format('import datasets'))
+    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)

     # using validation/test data from datasets
@@ -253,39 +270,7 @@ def process_task(args, task_name, ngrams):
     print(" After task {} entities in ngrams {}, added {}".format(task_name, \
         len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)

-if __name__ == '__main__':
-
-    # we use 13-grams, any text less than 200 characters got removed
-    # any text splitted more than 10 got removed as well
-
-    print('parsing the arguments ...')
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
-                        help = 'Tasks to use for deduplication: currently '
-                        ' suuport [lambada, squad, natural_questions,'
-                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
-    parser.add_argument('--lambada-path', type=str, default=None,
-                       help='Only Lambada task needs the path')
-    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
-                       help='Dataset to deduplicate with the key to use'
-                        ' e.g. cc.json text')
-    parser.add_argument('--output', type=str, default=None,
-                       help='Output file name to save dedup dataset')
-
-    # Default dedup values
-    parser.add_argument('--ngram-size', type=int, default=13,
-                       help='Maximum size of ngram to use.')
-    parser.add_argument('--filter-text-char-len', type=int, default=200,
-                       help='Remove any text below this length.')
-    parser.add_argument('--splits-count', type=int, default=10,
-                       help='Remove any documents more than this many splits')
-    parser.add_argument('--remove-char-each-side', type=int, default=200,
-                       help='Maximum size of ngram to use.')
-
-    args = parser.parse_args()
-
-    # Build ngrams
-    ngrams = {}
-
+def compute_tasks_ngrams(args, ngrams):
     start_time = time.time()
     for _, task_name in enumerate(args.tasks):
         print('Task: {}'.format(task_name), flush=True)
@@ -294,10 +279,10 @@ if __name__ == '__main__':
             process_task_lambda(args, args.lambada_path, ngrams)
         else:
             process_task(args, task_name, ngrams)
+    print(" Taken time to compute ngrams {:.2f}".format(time.time() - \
+        start_time), flush=True)

-    print(" Taken time {:.2f}".format(time.time() - start_time), flush=True)
-    # get the range of the size of the ngrams
+def compute_ngram_freq_sorted(args, ngrams):
     ngrams_freq = {}
     for ngram_key in ngrams.keys():
         length = len(ngram_key.split())
@@ -309,33 +294,74 @@ if __name__ == '__main__':
     print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
         len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
         ngrams_freq_sorted) -1 ][0]), flush=True)
-    id_prefix = '-'.join(args.tasks[::2])
-
-    print('Reading file {} and deduping n-grams'.format(args.dedup_dataset))
+    return ngrams_freq_sorted
+
+def get_ngrams_above_threshold(args, ngrams, ngrams_above_threshold, \
+    dedup_file, dedup_key, ngrams_freq_sorted):
+
+    start_time = time.time()
+    # get the ngrams frequency
+    args.get_ngram_freq_only = True
+
+    # Open the large file to process in parallel
+    num_workers = 40
+    pool = multiprocessing.Pool(num_workers)
+    fin = open(dedup_file, 'r', encoding='utf-8')
+    free_ngram_abt_partial=partial(free_ngram, args=args, key=dedup_key, \
+        ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
+    free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)

     counter = 0
+    for _, _, local_ngram in free_ngrams_abt:
+        counter += 1
+        if counter % 1000 == 0:
+            print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
+                format(counter, time.time() - start_time), flush=True)
+        for local_key in local_ngram:
+            if local_key in ngrams:
+                ngrams[local_key] += 1
+        local_ngram = {}
+
+    print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
+        start_time), flush=True)
+    pool.close()
+    pool.join()
+
     start_time = time.time()
-
-    if args.output is not None:
-        out_f = open(args.output, 'wb')
-    splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0
-
-    assert len(args.dedup_dataset) == 2
-    dedup_file = args.dedup_dataset[0]
-    dedup_key = args.dedup_dataset[1]
-
-    # Setup multi-processing.
+    counter_threshold = 0
+    # Get ngram above theadhold
+    for local_key, local_val in ngrams.items():
+        if ngrams[local_key] > args.key_threshold:
+            print(" [threshold] {} {}".format(local_key, local_val), flush=True)
+            counter_threshold += 1
+            ngrams_above_threshold[local_key] = 1
+
+    print(' Ngrams above threshold {}'.format(counter_threshold), flush=True)
+    fin.close()
+
+def clean_ngrams_above_threshold(args, ngrams_above_threshold, dedup_file, \
+    dedup_key):
+
+    start_time = time.time()
+    # Now actually filter the dataset
+    args.get_ngram_freq_only = False
+    id_prefix = '-'.join(args.tasks[::2])
+
+    # get the range of the size of the ngrams
+    ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_above_threshold)
+
+    # Open the large file to process in parallel
     num_workers = 40
-    fin = open(dedup_file, 'r', encoding='utf-8')
     pool = multiprocessing.Pool(num_workers)
-
-    free_ngram_x=partial(free_ngram, args=args, key=dedup_key, ngrams=ngrams, \
-        ngrams_freq_sorted=ngrams_freq_sorted)
-
-    free_ngrams = pool.imap(free_ngram_x, fin, 25)
-
-    for text_buf_ngram_free, trimmed in free_ngrams:
+    fin = open(dedup_file, 'r', encoding='utf-8')
+    free_ngram_clean_partial=partial(free_ngram, args=args, key=dedup_key, \
+        ngrams=ngrams_above_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
+    free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)
+
+    out_f = open(args.output, 'wb')
+    counter = splitted = ignored = split_mt_thld = trimmed_count = 0
+    for text_buf_ngram_free, trimmed, _ in free_ngrams_clean:
         counter += 1
         try:
@@ -361,18 +387,95 @@ if __name__ == '__main__':
                 out_f.write('\n'.encode('utf-8'))

             if counter % 1000 == 0:
-                print(' [search]> processed {} documents in {:.2f} seconds ...'.
+                print(' [final]> processed {} documents in {:.2f} seconds ...'.
                     format(counter, time.time() - start_time), flush=True)
         except Exception as e:
             print('Error:', e)

-    if args.output is not None:
-        out_f.close()
+    print(' [final]> processed {} documents in {:.2f} seconds ...'.
+        format(counter, time.time() - start_time), flush=True)
+    print(' Total docs {} splitted {} ignored {} splits > theshold {} trimmed'\
+        ' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count)\
+        , flush=True)
+
+    pool.close()
+    pool.join()
+
+    out_f.close()
     fin.close()

-    print("Deduped file written to: {}".format(args.output), flush=True)
-    print("Total docs {} splitted {} ignored {} docs with many splits {}"\
-        " trimmed {}".format(counter, splitted, ignored, split_mt_thld, \
-        trimmed_count), flush=True)
+if __name__ == '__main__':
+
+    # we use 13-grams, any text less than 200 characters got removed
+    # any text splitted more than 10 got removed as well
+
+    print('parsing the arguments ...')
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
+                        help = 'Tasks to use for deduplication: currently '
+                        ' suuport [lambada, squad, natural_questions,'
+                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
+    parser.add_argument('--lambada-path', type=str, default=None,
+                       help='Only Lambada task needs the path')
+    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
+                       help='Dataset to deduplicate with the key to use'
+                        ' e.g. cc.json text')
+    parser.add_argument('--output', type=str, default=None,
+                       help='Output file name to save dedup dataset')
+
+    # Default dedup values
+    parser.add_argument('--max-ngram-size', type=int, default=13,
+                       help='Maximum size of ngram to use.')
+    parser.add_argument('--min-ngram-size', type=int, default=8,
+                       help='Minimum size of ngram to use.')
+    parser.add_argument('--filter-text-char-len', type=int, default=200,
+                       help='Remove any text below this length.')
+    parser.add_argument('--key-threshold', type=int, default=10,
+                       help='Number of keys to consider as threshold')
+    parser.add_argument('--save-dictionary', type=str, default=None,
+                       help='Save the dictionary')
+    parser.add_argument('--load-dictionary', type=str, default=None,
+                       help='Load the dictionary')
+    parser.add_argument('--splits-count', type=int, default=10,
+                       help='Remove any documents more than this many splits')
+    parser.add_argument('--remove-char-each-side', type=int, default=200,
+                       help='Maximum size of ngram to use.')
+
+    args = parser.parse_args()
+
+    assert len(args.dedup_dataset) == 2
+    dedup_file = args.dedup_dataset[0]
+    dedup_key = args.dedup_dataset[1]
+
+    # Setup multi-processing
+    num_workers = 40
+
+    if args.load_dictionary is None:
+
+        # Build ngrams
+        ngrams = {}
+        compute_tasks_ngrams(args, ngrams)
+
+        # get the range of the size of the ngrams
+        ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)
+
+        # get ngram freq from large file in parallel
+        # get ngrams above threshold
+        ngrams_above_threshold = {}
+        get_ngrams_above_threshold(args, ngrams, ngrams_above_threshold, \
+            dedup_file, dedup_key, ngrams_freq_sorted)
+
+        # save the dictionary if needed
+        if args.save_dictionary is not None:
+            with open(args.save_dictionary, 'wb') as save_dict_handle:
+                pickle.dump(ngrams_above_threshold, save_dict_handle)
+    else:
+        with open(args.load_dictionary, 'rb') as load_dict_handle:
+            ngrams_above_threshold = pickle.load(load_dict_handle)
+
+    # filter the large file
+    if args.output is not None:
+        clean_ngrams_above_threshold(args, ngrams_above_threshold, \
+            dedup_file, dedup_key)
+
     print('done :-)')
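
Note: as a rough illustration of the two-pass workflow this commit introduces (first count how often each task n-gram shows up in the corpus, keep only the n-grams whose count exceeds --key-threshold, then filter the corpus against that reduced set), here is a minimal standalone sketch. It is not the committed script: the helper names, the toy corpus, the 3-word n-grams, and the threshold of 1 are invented for the example, and it drops whole documents instead of splitting text around matches as the real script does.

# Minimal sketch of the frequency-threshold idea, on made-up data.

def build_ngrams(text, n):
    # all n-word sequences of the task text, joined with single spaces
    words = text.split()
    return {" ".join(words[i:i + n]) for i in range(len(words) - n + 1)}

def count_ngram_hits(corpus_docs, task_ngrams):
    # pass 1: how many corpus documents contain each task n-gram
    counts = {seq: 0 for seq in task_ngrams}
    for doc in corpus_docs:
        for seq in task_ngrams:
            if seq in doc:
                counts[seq] += 1
    return counts

def filter_docs(corpus_docs, ngrams_above_threshold):
    # pass 2: keep only documents containing none of the frequent task n-grams
    return [doc for doc in corpus_docs
            if not any(seq in doc for seq in ngrams_above_threshold)]

if __name__ == "__main__":
    task_text = "the quick brown fox jumps over the lazy dog"
    corpus = [
        "a quick brown fox appears in many crawled pages",
        "completely unrelated sentence about databases",
        "another page with the quick brown fox phrase repeated",
    ]
    task_ngrams = build_ngrams(task_text, 3)              # toy max n-gram size
    counts = count_ngram_hits(corpus, task_ngrams)
    above = {seq for seq, c in counts.items() if c > 1}   # toy key_threshold = 1
    print("frequent task n-grams:", sorted(above))
    print("kept documents:", filter_docs(corpus, above))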