"server/vscode:/vscode.git/clone" did not exist on "8511669cb29115bdf0bc2da5328e69d041030996"
Commit 6013e23c authored by Mostofa Patwary

Add dedup support for more tasks (natural_questions, triviaqa, webqa, race, drop, coqa, piqa)

parent b08b5edc
@@ -162,7 +162,7 @@ def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
     # check if the text has only been trimmed
     trimmed = 0
-    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) == \
+    if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \
         len(myjson[key]):
         trimmed = 1
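A quick worked example of the comparison fix above (not part of the commit; the strings are made up): with `==`, the flag fired only when the surviving fragment was exactly as long as the original, i.e. when nothing had been removed, so genuinely trimmed documents were never counted. With `<`, trimmed means the fragment is strictly shorter.

```python
# Sketch of the fixed trimmed check with made-up values.
original = "the quick brown fox jumped"
text_buf_ngram_free = ["the quick brown"]  # trailing ngram removed

trimmed = 0
if len(text_buf_ngram_free) == 1 and len(text_buf_ngram_free[0]) < \
        len(original):
    trimmed = 1
print(trimmed)  # 1: the surviving fragment is strictly shorter
```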
@@ -201,21 +201,57 @@ def process_task_lambda(args, task_file, ngrams):
     print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
 
-# Build ngrams for the squad v2 dataset
-def process_task_squad(args, ngrams):
+# Build ngrams for the dataset of the given task
+def process_task(args, task_name, ngrams):
 
-    print(' reading from {} and computing ngrams'.format('import datasets'))
-    # using squad data from datasets
+    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
+    # using validation/test data from datasets
     from datasets import load_dataset
-    squad_v2 = load_dataset('squad_v2', split='validation')
-    for line in squad_v2:
+    entities_in_ngrams = len(ngrams)
+
+    # load the dataset
+    if task_name == 'squad':
+        dataset = load_dataset('squad_v2', split='validation')
+    elif task_name == 'natural_questions':
+        dataset = load_dataset('natural_questions', split='validation')
+    elif task_name == 'triviaqa':
+        dataset = load_dataset('trivia_qa', 'unfiltered', split='test')
+    elif task_name == 'webqa':
+        dataset = load_dataset('web_questions', split='test')
+    elif task_name == 'race':
+        dataset = load_dataset('race', 'all', split='test')
+    elif task_name == 'drop':
+        dataset = load_dataset('drop', split='validation')
+    elif task_name == 'coqa':
+        dataset = load_dataset('coqa', split='validation')
+    elif task_name == 'piqa':
+        dataset = load_dataset('piqa', split='test')
+    else:
+        print("Invalid task name: {}".format(task_name), flush=True)
+        return
+
+    # read the dataset and add to ngrams
+    for line in dataset:
         try:
-            text = line['question']
-            compute_ngrams_insert_dict(args, text, ngrams)
+            if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']:
+                text = line['question']
+                compute_ngrams_insert_dict(args, text, ngrams)
+            elif task_name == 'natural_questions':
+                text = line['question']['text']
+                compute_ngrams_insert_dict(args, text, ngrams)
+            elif task_name == 'coqa':
+                all_questions = line['questions']
+                for question in all_questions:
+                    compute_ngrams_insert_dict(args, question, ngrams)
+            elif task_name == 'piqa':
+                text = line['goal']
+                compute_ngrams_insert_dict(args, text, ngrams)
         except Exception as e:
             print('Error:', e)
 
-    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
+    print(" After task {} entities in ngrams {}, added {}".format(task_name, \
+        len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)
 
 if __name__ == '__main__':
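For reference, a minimal sketch (not part of the commit) of why the field access branches per task: each Hugging Face dataset exposes its question text under a different key. The records below are illustrative stand-ins; only the key names follow the code above.

```python
# Illustrative records mirroring the per-task field layout used above.
squad_line = {'question': 'When did ... ?'}            # also triviaqa, webqa, race, drop
nq_line    = {'question': {'text': 'when did ... ?'}}  # natural_questions nests the text
coqa_line  = {'questions': ['Who ... ?', 'Why ... ?']} # one question list per conversation
piqa_line  = {'goal': 'How do you ... ?'}              # piqa's prompt field is 'goal'

print(squad_line['question'])
print(nq_line['question']['text'])
for question in coqa_line['questions']:
    print(question)
print(piqa_line['goal'])
```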
@@ -227,7 +263,8 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
         help = 'Tasks to use for deduplication: currently '
-        ' suuport [lambada, squad]')
+        ' support [lambada, squad, natural_questions,'
+        ' triviaqa, webqa, race, drop, coqa, and piqa]')
     parser.add_argument('--lambada-path', type=str, default=None,
         help='Only Lambada task needs the path')
     parser.add_argument('--dedup-dataset', nargs = '*', default=None,
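A small self-contained sketch of how the `nargs='*'` flags above behave; the argv values are hypothetical:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--tasks', nargs='*', required=True, default=None)
parser.add_argument('--dedup-dataset', nargs='*', default=None)

# Hypothetical command line: dedup several tasks against one json corpus.
args = parser.parse_args(['--tasks', 'squad', 'piqa',
                          '--dedup-dataset', 'corpus.json', 'text'])
print(args.tasks)          # ['squad', 'piqa']
print(args.dedup_dataset)  # ['corpus.json', 'text'] -- two items, as asserted later
```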
@@ -249,13 +286,16 @@ if __name__ == '__main__':
     # Build ngrams
     ngrams = {}
     start_time = time.time()
     for _, task_name in enumerate(args.tasks):
         print('Task: {}'.format(task_name), flush=True)
         if task_name == 'lambada':
             assert args.lambada_path is not None
             process_task_lambda(args, args.lambada_path, ngrams)
-        if task_name == 'squad':
-            process_task_squad(args, ngrams)
+        else:
+            process_task(args, task_name, ngrams)
     print(" Taken time {:.2f}".format(time.time() - start_time), flush=True)
 
     # get the range of the size of the ngrams
     ngrams_freq = {}
@@ -263,8 +303,8 @@
         length = len(ngram_key.split())
         ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
             ngrams_freq else 1
-    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[1])
+    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])
 
     print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
     print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
         len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
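A worked example (not in the commit; the counts are made up) of why the sort key moved from `item[1]` to `item[0]`: the min/max prints above read the first and last entries of `ngrams_freq_sorted`, which is only meaningful when the (length, count) pairs are sorted by length, not by count.

```python
ngrams_freq = {13: 5000, 8: 120, 25: 7}  # made-up length -> count map

by_count  = sorted(ngrams_freq.items(), key=lambda item: item[1])
by_length = sorted(ngrams_freq.items(), key=lambda item: item[0])

print(by_count[0][0])   # 25 -- the rarest length, not the minimum size
print(by_length[0][0])  # 8  -- the actual min_ngram_size
```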
@@ -276,7 +316,10 @@ if __name__ == '__main__':
     counter = 0
     start_time = time.time()
-    out_f = open(args.output, 'wb')
+    if args.output is not None:
+        out_f = open(args.output, 'wb')
 
     splitted, ignored, split_mt_thld, trimmed_count = 0, 0, 0, 0
     assert len(args.dedup_dataset) == 2
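The guard above makes the output file optional; the same `args.output is not None` check reappears at the write and close sites further down. A condensed sketch of the resulting stats-only mode (names hypothetical):

```python
# Minimal sketch: with no --output, the script still computes the
# splitted/ignored/trimmed statistics but never opens or writes a file.
output = None  # stands in for args.output
out_f = None
if output is not None:
    out_f = open(output, 'wb')
# ... dedup loop runs here; writes are wrapped in the same guard ...
if output is not None:
    out_f.close()
```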
@@ -299,7 +342,7 @@ if __name__ == '__main__':
             trimmed_count += trimmed
 
             if len(text_buf_ngram_free) > 1:
-                splitted += (len(text_buf_ngram_free) - 1)
+                splitted += 1
             if len(text_buf_ngram_free) == 0:
                 ignored += 1
             # more than 10 splits ignored
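A tiny worked example (values made up) of the counter-semantics change just above: the old code counted surplus fragments, the new code counts documents that were split at least once.

```python
text_buf_ngram_free = ['frag a', 'frag b', 'frag c']  # one doc, 3 fragments

splitted_old = len(text_buf_ngram_free) - 1  # 2: fragments beyond the first
splitted_new = 1                             # 1: one document was split
```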
@@ -307,14 +350,15 @@ if __name__ == '__main__':
                 text_buf_ngram_free = []
                 split_mt_thld += 1
 
-            for i in range(len(text_buf_ngram_free)):
-                split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \
-                    + '-{:010d}'.format(int(i))
-                outjson = json.dumps({"text":text_buf_ngram_free[i],
-                    id_prefix+"_split_id":split_id_string},
-                    ensure_ascii=False)
-                out_f.write(outjson.encode('utf-8'))
-                out_f.write('\n'.encode('utf-8'))
+            if args.output is not None:
+                for i in range(len(text_buf_ngram_free)):
+                    split_id_string = id_prefix + '-{:010d}'.format(int(\
+                        counter)) + '-{:010d}'.format(int(i))
+                    outjson = json.dumps({"text":text_buf_ngram_free[i],
+                        id_prefix+"_split_id":split_id_string},
+                        ensure_ascii=False)
+                    out_f.write(outjson.encode('utf-8'))
+                    out_f.write('\n'.encode('utf-8'))
 
             if counter % 1000 == 0:
                 print(' [search]> processed {} documents in {:.2f} seconds ...'.
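For reference, a self-contained sketch (all values hypothetical; `id_prefix` is defined elsewhere in the script) of the json line the loop above emits, including the zero-padded document/fragment split id:

```python
import json

id_prefix = 'mydataset'  # stand-in; the script derives this elsewhere
counter, i = 42, 0       # document counter and fragment index

split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \
    + '-{:010d}'.format(int(i))
outjson = json.dumps({"text": "ngram-free fragment",
                      id_prefix + "_split_id": split_id_string},
                     ensure_ascii=False)
print(outjson)
# {"text": "ngram-free fragment", "mydataset_split_id": "mydataset-0000000042-0000000000"}
```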
@@ -322,7 +366,9 @@ if __name__ == '__main__':
         except Exception as e:
             print('Error:', e)
 
-    out_f.close()
+    if args.output is not None:
+        out_f.close()
     fin.close()
 
     print("Deduped file written to: {}".format(args.output), flush=True)