Commit 6013e23c authored by Mostofa Patwary

Dedup for other tasks added

parent b08b5edc
@@ -162,7 +162,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
     # check if the text has only been trimmed
     trimmed = 0
-    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) == \
+    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \
         len(myjson[key]):
         trimmed = 1
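A note on the hunk above: text_buf_ngram_free holds the ngram-free chunks of a document and myjson[key] the original text, so a single surviving chunk can only mean the document was trimmed when that chunk is strictly shorter than the original; the old equality test instead flagged documents from which nothing had been removed.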
@@ -201,21 +201,57 @@ def process_task_lambda(args, task_file, ngrams):
     print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
 
-# Build ngrams for the squad v2 dataset
-def process_task_squad(args, ngrams):
+# Build ngrams for the dataset of the given task
+def process_task(args, task_name, ngrams):
     print(' reading from {} and computing ngrams'.format('import datasets'))
-    # using squad data from datasets
+    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
+
+    # using validation/test data from datasets
     from datasets import load_dataset
-    squad_v2 = load_dataset('squad_v2', split='validation')
-    for line in squad_v2:
+
+    entities_in_ngrams = len(ngrams)
+
+    # load the dataset
+    if task_name == 'squad':
+        dataset = load_dataset('squad_v2', split='validation')
+    elif task_name == 'natural_questions':
+        dataset = load_dataset('natural_questions', split='validation')
+    elif task_name == 'triviaqa':
+        dataset = load_dataset('trivia_qa', 'unfiltered', split='test')
+    elif task_name == 'webqa':
+        dataset = load_dataset('web_questions', split='test')
+    elif task_name == 'race':
+        dataset = load_dataset('race', 'all', split='test')
+    elif task_name == 'drop':
+        dataset = load_dataset('drop', split='validation')
+    elif task_name == 'coqa':
+        dataset = load_dataset('coqa', split='validation')
+    elif task_name == 'piqa':
+        dataset = load_dataset('piqa', split='test')
+    else:
+        print("Invalid task name: {}".format(task_name), flush=True)
+        return
+
+    # read the dataset and add to ngrams
+    for line in dataset:
         try:
-            text = line['question']
-            compute_ngrams_insert_dict(args, text, ngrams)
+            if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']:
+                text = line['question']
+                compute_ngrams_insert_dict(args, text, ngrams)
+            elif task_name == 'natural_questions':
+                text = line['question']['text']
+                compute_ngrams_insert_dict(args, text, ngrams)
+            elif task_name == 'coqa':
+                all_questions = line['questions']
+                for question in all_questions:
+                    compute_ngrams_insert_dict(args, question, ngrams)
+            elif task_name == 'piqa':
+                text = line['goal']
+                compute_ngrams_insert_dict(args, text, ngrams)
         except Exception as e:
             print('Error:', e)
-    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
+
+    print(" After task {} entities in ngrams {}, added {}".format(task_name, \
+        len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)
 
 if __name__ == '__main__':
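For context, compute_ngrams_insert_dict is defined earlier in this file and is untouched by the commit. A minimal sketch of what such a helper could look like, assuming the window width arrives through an args field (the name ngram_size and the whole body are illustrative assumptions, not the actual implementation):

    # Hypothetical sketch only; the real compute_ngrams_insert_dict lives
    # elsewhere in this file and may differ.
    def compute_ngrams_insert_dict(args, text, ngrams):
        words = text.lower().split()
        if len(words) < args.ngram_size:
            # keep short task texts whole, so keys of varying lengths can
            # appear in the dict (consistent with the min/max ngram size
            # report printed later in this script)
            ngrams[' '.join(words)] = 0
        else:
            # record the leading fixed-size token window of the task text
            ngrams[' '.join(words[0:args.ngram_size])] = 0

Whatever the exact body, the keys are space-joined token windows, which is why the frequency pass below recovers each ngram's length with len(ngram_key.split()).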
@@ -227,7 +263,8 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
                         help = 'Tasks to use for deduplication: currently '
-                        ' suuport [lambada, squad]')
+                        ' support [lambada, squad, natural_questions,'
+                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
     parser.add_argument('--lambada-path', type=str, default=None,
                        help='Only Lambada task needs the path')
     parser.add_argument('--dedup-dataset', nargs = '*', default=None,
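With the expanded task list, an invocation of the script could look like the following (the script name and file paths are placeholders; judging by the assert len(args.dedup_dataset) == 2 check further down, --dedup-dataset takes a JSONL file followed by the JSON key that holds the text):

    python dedup_tasks.py --tasks squad triviaqa piqa \
        --dedup-dataset train.json text \
        --output train_deduped.json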
@@ -249,13 +286,16 @@ if __name__ == '__main__':
     # Build ngrams
     ngrams = {}
+    start_time = time.time()
     for _, task_name in enumerate(args.tasks):
         print('Task: {}'.format(task_name), flush=True)
         if task_name == 'lambada':
             assert args.lambada_path is not None
             process_task_lambda(args, args.lambada_path, ngrams)
-        if task_name == 'squad':
-            process_task_squad(args, ngrams)
+        else:
+            process_task(args, task_name, ngrams)
+    print(" Taken time {:.2f}".format(time.time() - start_time), flush=True)
 
     # get the range of the size of the ngrams
     ngrams_freq = {}
@@ -263,8 +303,8 @@ if __name__ == '__main__':
         length = len(ngram_key.split())
         ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
             ngrams_freq else 1
-    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[1])
+    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])
     print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
     print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
         len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
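The key change from item[1] to item[0] matters because the summary line reads the first and last entries of ngrams_freq_sorted as the minimum and maximum ngram sizes, which only holds when the (length, count) pairs are sorted by length rather than by count. A small illustration with made-up counts:

    ngrams_freq = {13: 300, 5: 7, 25: 120}
    sorted(ngrams_freq.items(), key=lambda item: item[1])
    # by count:  [(5, 7), (25, 120), (13, 300)] -- max size reported as 13
    sorted(ngrams_freq.items(), key=lambda item: item[0])
    # by length: [(5, 7), (13, 300), (25, 120)] -- min 5, max 25, as intended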
@@ -276,7 +316,10 @@ if __name__ == '__main__':
     counter = 0
     start_time = time.time()
-    out_f = open(args.output, 'wb')
+
+    if args.output is not None:
+        out_f = open(args.output, 'wb')
+
     splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0
 
     assert len(args.dedup_dataset) == 2
@@ -299,7 +342,7 @@ if __name__ == '__main__':
            trimmed_count += trimmed
 
            if len(text_buf_ngram_free) > 1:
-                splitted += (len(text_buf_ngram_free) - 1)
+                splitted += 1
            if len(text_buf_ngram_free) == 0:
                ignored += 1
            # more than 10 splits ignored
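This changes what splitted measures: a document broken into three ngram-free chunks previously added two to the counter (the number of extra fragments) and now adds one (the number of documents that were split at all).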
@@ -307,14 +350,15 @@ if __name__ == '__main__':
                text_buf_ngram_free = []
                split_mt_thld += 1
 
-            for i in range(len(text_buf_ngram_free)):
-                split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \
-                    + '-{:010d}'.format(int(i))
-                outjson = json.dumps({"text":text_buf_ngram_free[i],
-                    id_prefix+"_split_id":split_id_string},
-                    ensure_ascii=False)
-                out_f.write(outjson.encode('utf-8'))
-                out_f.write('\n'.encode('utf-8'))
+            if args.output is not None:
+                for i in range(len(text_buf_ngram_free)):
+                    split_id_string = id_prefix + '-{:010d}'.format(int(\
+                        counter)) + '-{:010d}'.format(int(i))
+                    outjson = json.dumps({"text":text_buf_ngram_free[i],
+                        id_prefix+"_split_id":split_id_string},
+                        ensure_ascii=False)
+                    out_f.write(outjson.encode('utf-8'))
+                    out_f.write('\n'.encode('utf-8'))
 
            if counter % 1000 == 0:
                print(' [search]> processed {} documents in {:.2f} seconds ...'.
@@ -322,7 +366,9 @@ if __name__ == '__main__':
        except Exception as e:
            print('Error:', e)
 
-    out_f.close()
+    if args.output is not None:
+        out_f.close()
+
     fin.close()
     print("Deduped file written to: {}".format(args.output), flush=True)