Commit 8c9c145d authored by zihanl's avatar zihanl
Browse files

update preprocessing.py

parent 91a80bd1
...@@ -13,16 +13,20 @@ def get_args(): ...@@ -13,16 +13,20 @@ def get_args():
parser.add_argument("--func", type=str, default=None, parser.add_argument("--func", type=str, default=None,
help="choose to run which function") help="choose to run which function")
parser.add_argument("--input_file", type=str, default=None, parser.add_argument("--raw_file", type=str, default=None,
help="path of the input file") help="path of the input file")
parser.add_argument("--knowledge_file", type=str, default=None, parser.add_argument("--processed_file", type=str, default=None,
help="path of the knowledge file") help="path of the output file")
parser.add_argument("--knwl_ref_file", type=str, default=None,
help="path of the knowledge reference file")
parser.add_argument("--resp_ref_file", type=str, default=None,
help="path of the response reference file")
parser.add_argument("--knwl_gen_file", type=str, default=None,
help="path of the generated knowledge file")
parser.add_argument("--test_file", type=str, default=None, parser.add_argument("--test_file", type=str, default=None,
help="path of the test file") help="path of the test file")
parser.add_argument("--train_file", type=str, default=None, parser.add_argument("--train_file", type=str, default=None,
help="path of the train file") help="path of the train file")
parser.add_argument("--output_file", type=str, default=None,
help="path of the output file")
parser.add_argument("--model_file", type=str, default=None, parser.add_argument("--model_file", type=str, default=None,
help="path of the model file") help="path of the model file")
parser.add_argument("--data_type", type=str, default=None, parser.add_argument("--data_type", type=str, default=None,
...@@ -34,19 +38,22 @@ def get_args(): ...@@ -34,19 +38,22 @@ def get_args():
return args return args
def process_wow_dataset(input_file, output_file): def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
""" """
This is a function used for processing the wizard of wikipedia (wow) dataset This is a function used for processing the wizard of wikipedia (wow) dataset
Expected processed format: Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response topic \t dialogue context \t golden knowledge \t golden response
""" """
print("> Loading data from %s" % input_file) print("> Loading data from %s" % raw_file)
with open(input_file, "r") as fr: with open(raw_file, "r") as fr:
dialog_data = json.load(fr) dialog_data = json.load(fr)
print("> Processing data ...") print("> Processing data ...")
with open(output_file, "w") as fw: fproc = open(processed_file, "w")
fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
fresp = open(resp_ref_file, "w") if resp_ref_file else None
for i, sample in enumerate(tqdm(dialog_data)): for i, sample in enumerate(tqdm(dialog_data)):
# get all the dialog data for a single sample # get all the dialog data for a single sample
dialog = sample["dialog"] dialog = sample["dialog"]
...@@ -86,26 +93,45 @@ def process_wow_dataset(input_file, output_file): ...@@ -86,26 +93,45 @@ def process_wow_dataset(input_file, output_file):
else: else:
topic = sample["chosen_topic"] topic = sample["chosen_topic"]
# write to the output file knowledge = checked_sentence
fw.write(topic + "\t" + " [SEP] ".join(context) + "\t" + \ response = text
checked_sentence + "\t" + text + "\n") # write to the output files
fproc.write(topic + "\t" + " [SEP] ".join(context) + "\t" + \
knowledge + "\t" + response + "\n")
if fknwl:
fknwl.write(knowledge + "\n")
if fresp:
# tokenize for evaluation
response = " ".join(word_tokenize(response))
fresp.write(response + "\n")
context.append(text) context.append(text)
else: else:
assert "apprentice" in speaker assert "apprentice" in speaker
context.append(text) context.append(text)
fproc.close()
if fknwl:
fknwl.close()
if fresp:
fresp.close()
def process_woi_dataset(input_file, output_file): def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
""" """
This is a function used for processing the wizard of internet (woi) dataset This is a function used for processing the wizard of internet (woi) dataset
Expected processed format: Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response topic \t dialogue context \t golden knowledge \t golden response
""" """
print("> Processing %s" % input_file) print("> Processing %s" % raw_file)
with open(output_file, "w") as fw: fproc = open(processed_file, "w")
with open(input_file, "r") as fr: fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
fresp = open(resp_ref_file, "w") if resp_ref_file else None
with open(raw_file, "r") as fr:
for i, line in tqdm(enumerate(fr)): for i, line in tqdm(enumerate(fr)):
line = line.strip() line = line.strip()
item_dict = json.loads(line) item_dict = json.loads(line)
...@@ -173,10 +199,16 @@ def process_woi_dataset(input_file, output_file): ...@@ -173,10 +199,16 @@ def process_woi_dataset(input_file, output_file):
response = response.replace("\n", "").replace("\r", \ response = response.replace("\n", "").replace("\r", \
"").replace("\t", "") "").replace("\t", "")
# write to the output file
if topic != "no_topic": if topic != "no_topic":
fw.write(topic + "\t" + dialog_context + "\t" + \ # write to the output files
fproc.write(topic + "\t" + dialog_context + "\t" + \
knwl_sent + "\t" + response + "\n") knwl_sent + "\t" + response + "\n")
if fknwl:
fknwl.write(knwl_sent + "\n")
if fresp:
# tokenize for evaluation
response = " ".join(word_tokenize(response))
fresp.write(response + "\n")
turn_list.append(response) turn_list.append(response)
...@@ -187,6 +219,12 @@ def process_woi_dataset(input_file, output_file): ...@@ -187,6 +219,12 @@ def process_woi_dataset(input_file, output_file):
else: else:
assert action == "SearchAgent => Wizard" assert action == "SearchAgent => Wizard"
fproc.close()
if fknwl:
fknwl.close()
if fresp:
fresp.close()
def get_database(test_datapath, train_datapath, data_type): def get_database(test_datapath, train_datapath, data_type):
"""Get the database by topics""" """Get the database by topics"""
...@@ -465,7 +503,6 @@ def prompt_selection_for_response_generation(input_path, output_path, seed): ...@@ -465,7 +503,6 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):
prompt_example_list.append(prompt_example) prompt_example_list.append(prompt_example)
# shuffle the prompt examples # shuffle the prompt examples
print("length: %d" % len(prompt_example_list))
np.random.shuffle(prompt_example_list) np.random.shuffle(prompt_example_list)
print("> writing to %s" % output_path) print("> writing to %s" % output_path)
...@@ -476,17 +513,17 @@ def prompt_selection_for_response_generation(input_path, output_path, seed): ...@@ -476,17 +513,17 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):
f.write(example + "\n") f.write(example + "\n")
def prepare_input_for_response_generation(test_file, knowledge_file, output_file): def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
"""Preparing inputs for the response generation""" """Preparing inputs for the response generation"""
print("> Reading knowledge file from %s" % knowledge_file) print("> Reading knowledge file from %s" % knwl_gen_file)
# get the knowledge list # get the knowledge list
with open(knowledge_file, "r") as f: with open(knwl_gen_file, "r") as f:
knowledge_list = f.readlines() knowledge_list = f.readlines()
print("> Processing ...") print("> Processing ...")
with open(test_file, "r") as fr: with open(test_file, "r") as fr:
with open(output_file, "w") as fw: with open(processed_file, "w") as fw:
for line_num, line in enumerate(tqdm(fr)): for line_num, line in enumerate(tqdm(fr)):
line = line.strip() line = line.strip()
splits = line.split("\t") splits = line.split("\t")
...@@ -508,20 +545,20 @@ if __name__ == "__main__": ...@@ -508,20 +545,20 @@ if __name__ == "__main__":
args = get_args() args = get_args()
if args.func == "process_wow_dataset": if args.func == "process_wow_dataset":
process_wow_dataset(args.input_file, args.output_file) process_wow_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
elif args.func == "process_woi_dataset": elif args.func == "process_woi_dataset":
process_woi_dataset(args.input_file, args.output_file) process_woi_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
elif args.func == "get_knwl_gen_prompts": elif args.func == "get_knwl_gen_prompts":
prompt_selection_for_knowledge_generation( prompt_selection_for_knowledge_generation(
args.test_file, args.train_file, args.model_file, args.test_file, args.train_file, args.model_file,
args.output_file, args.data_type) args.processed_file, args.data_type)
elif args.func == "get_resp_gen_prompts": elif args.func == "get_resp_gen_prompts":
prompt_selection_for_response_generation( prompt_selection_for_response_generation(
args.train_file, args.output_file, args.seed) args.train_file, args.processed_file, args.seed)
elif args.func == "prepare_input": elif args.func == "prepare_input":
prepare_input_for_response_generation( prepare_input_for_response_generation(
args.test_file, args.knowledge_file, args.output_file) args.test_file, args.knwl_gen_file, args.processed_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment