Commit 8c9c145d authored by zihanl

update preprocessing.py

parent 91a80bd1
@@ -13,16 +13,20 @@ def get_args():
     parser.add_argument("--func", type=str, default=None,
                         help="choose to run which function")
-    parser.add_argument("--input_file", type=str, default=None,
-                        help="path of the input file")
-    parser.add_argument("--knowledge_file", type=str, default=None,
-                        help="path of the knowledge file")
+    parser.add_argument("--raw_file", type=str, default=None,
+                        help="path of the input file")
+    parser.add_argument("--processed_file", type=str, default=None,
+                        help="path of the output file")
+    parser.add_argument("--knwl_ref_file", type=str, default=None,
+                        help="path of the knowledge reference file")
+    parser.add_argument("--resp_ref_file", type=str, default=None,
+                        help="path of the response reference file")
+    parser.add_argument("--knwl_gen_file", type=str, default=None,
+                        help="path of the generated knowledge file")
     parser.add_argument("--test_file", type=str, default=None,
                         help="path of the test file")
     parser.add_argument("--train_file", type=str, default=None,
                         help="path of the train file")
-    parser.add_argument("--output_file", type=str, default=None,
-                        help="path of the output file")
     parser.add_argument("--model_file", type=str, default=None,
                         help="path of the model file")
     parser.add_argument("--data_type", type=str, default=None,
@@ -34,158 +38,192 @@ def get_args():
     return args


-def process_wow_dataset(input_file, output_file):
+def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
     """
     This is a function used for processing the wizard of wikipedia (wow) dataset
     Expected processed format:
     topic \t dialogue context \t golden knowledge \t golden response
     """
-    print("> Loading data from %s" % input_file)
-    with open(input_file, "r") as fr:
+    print("> Loading data from %s" % raw_file)
+    with open(raw_file, "r") as fr:
         dialog_data = json.load(fr)

     print("> Processing data ...")
-    with open(output_file, "w") as fw:
-        for i, sample in enumerate(tqdm(dialog_data)):
-            # get all the dialog data for a single sample
-            dialog = sample["dialog"]
-            context = []
-            for j, turn in enumerate(dialog):
-                text = turn["text"]
-                if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
-                    text = text + "."
-                if j == 0:
-                    # first turn
-                    context.append(text)
-                    continue
-                speaker = turn["speaker"].lower()
-                if "wizard" in speaker:
-                    checked_sentence = list(turn["checked_sentence"].values()) # knowledge
-                    checked_passage = list(turn["checked_passage"].values()) # topic
-                    assert len(checked_sentence) <= 1
-                    # get the ground truth knowledge
-                    if len(checked_sentence) > 0:
-                        checked_sentence = checked_sentence[0]
-                    else:
-                        checked_sentence = "no_passages_used"
-                    if len(checked_passage) == 1:
-                        checked_passage = checked_passage[0]
-                    else:
-                        checked_passage = "no_passages_used"
-                    # get the topic
-                    if checked_passage != "no_passages_used":
-                        topic = checked_passage
-                    else:
-                        topic = sample["chosen_topic"]
-                    # write to the output file
-                    fw.write(topic + "\t" + " [SEP] ".join(context) + "\t" + \
-                                checked_sentence + "\t" + text + "\n")
-                    context.append(text)
-                else:
-                    assert "apprentice" in speaker
-                    context.append(text)
+    fproc = open(processed_file, "w")
+    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
+    fresp = open(resp_ref_file, "w") if resp_ref_file else None
+
+    for i, sample in enumerate(tqdm(dialog_data)):
+        # get all the dialog data for a single sample
+        dialog = sample["dialog"]
+        context = []
+        for j, turn in enumerate(dialog):
+            text = turn["text"]
+            if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
+                text = text + "."
+
+            if j == 0:
+                # first turn
+                context.append(text)
+                continue
+
+            speaker = turn["speaker"].lower()
+            if "wizard" in speaker:
+                checked_sentence = list(turn["checked_sentence"].values()) # knowledge
+                checked_passage = list(turn["checked_passage"].values()) # topic
+                assert len(checked_sentence) <= 1
+
+                # get the ground truth knowledge
+                if len(checked_sentence) > 0:
+                    checked_sentence = checked_sentence[0]
+                else:
+                    checked_sentence = "no_passages_used"
+
+                if len(checked_passage) == 1:
+                    checked_passage = checked_passage[0]
+                else:
+                    checked_passage = "no_passages_used"
+
+                # get the topic
+                if checked_passage != "no_passages_used":
+                    topic = checked_passage
+                else:
+                    topic = sample["chosen_topic"]
+
+                knowledge = checked_sentence
+                response = text
+                # write to the output files
+                fproc.write(topic + "\t" + " [SEP] ".join(context) + "\t" + \
+                            knowledge + "\t" + response + "\n")
+                if fknwl:
+                    fknwl.write(knowledge + "\n")
+                if fresp:
+                    # tokenize for evaluation
+                    response = " ".join(word_tokenize(response))
+                    fresp.write(response + "\n")
+
+                context.append(text)
+            else:
+                assert "apprentice" in speaker
+                context.append(text)
+
+    fproc.close()
+    if fknwl:
+        fknwl.close()
+    if fresp:
+        fresp.close()


-def process_woi_dataset(input_file, output_file):
+def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
     """
     This is a function used for processing the wizard of internet (woi) dataset
     Expected processed format:
     topic \t dialogue context \t golden knowledge \t golden response
     """
-    print("> Processing %s" % input_file)
-    with open(output_file, "w") as fw:
-        with open(input_file, "r") as fr:
-            for i, line in tqdm(enumerate(fr)):
-                line = line.strip()
-                item_dict = json.loads(line)
-                item_dict = item_dict.values()
-                assert len(item_dict) == 1
-                item_dict = list(item_dict)[0]
-                dialog_data = item_dict['dialog_history']
-                length = len(dialog_data)
-                turn_list = []
-                search_text = ""
-                for i in range(length):
-                    item = dialog_data[i]
-                    action = item['action']
-                    if action == "Wizard => SearchAgent":
-                        search_text = item['text']
-                    elif action == "Wizard => Apprentice":
-                        if len(turn_list) == 0:
-                            turn = item['text']
-                            turn_list.append(turn)
-                            continue
-                        # get the relevant content
-                        contents = item["context"]["contents"]
-                        selects = item["context"]["selected_contents"]
-                        flag = selects[0][0]
-                        selects = selects[1:]
-                        assert len(selects) == len(contents)
-                        # get the topic
-                        if flag:
-                            # no knowledge sentence is used
-                            topic = "no_topic"
-                            sent_list = ["no_passages_used"]
-                        else:
-                            # assert search_text != ""
-                            topic = search_text
-                            sent_list = []
-                            for content, select in zip(contents, selects):
-                                content = content['content']
-                                assert len(content) == len(select)
-                                for c, s in zip(content, select):
-                                    if s:
-                                        sent_list.append(c)
-                            if len(sent_list) == 0:
-                                topic = "no_topic"
-                                sent_list = ["no_passages_used"]
-                        # get dialogue context, knowledge, and response
-                        dialog_context = " [SEP] ".join(turn_list)
-                        knwl_sent = sent_list[0]
-                        response = item['text']
-                        # processing
-                        topic = topic.replace("\n", "").replace("\r", \
-                            "").replace("\t", "")
-                        dialog_context = dialog_context.replace("\n", "").replace("\r", \
-                            "").replace("\t", "")
-                        knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
-                            "").replace("\t", "")
-                        response = response.replace("\n", "").replace("\r", \
-                            "").replace("\t", "")
-                        # write to the output file
-                        if topic != "no_topic":
-                            fw.write(topic + "\t" + dialog_context + "\t" + \
-                                knwl_sent + "\t" + response + "\n")
-                        turn_list.append(response)
-                    elif action == "Apprentice => Wizard":
-                        turn = item['text']
-                        turn_list.append(turn)
-                    else:
-                        assert action == "SearchAgent => Wizard"
+    print("> Processing %s" % raw_file)
+    fproc = open(processed_file, "w")
+    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
+    fresp = open(resp_ref_file, "w") if resp_ref_file else None
+
+    with open(raw_file, "r") as fr:
+        for i, line in tqdm(enumerate(fr)):
+            line = line.strip()
+            item_dict = json.loads(line)
+            item_dict = item_dict.values()
+            assert len(item_dict) == 1
+            item_dict = list(item_dict)[0]
+
+            dialog_data = item_dict['dialog_history']
+            length = len(dialog_data)
+
+            turn_list = []
+            search_text = ""
+            for i in range(length):
+                item = dialog_data[i]
+                action = item['action']
+
+                if action == "Wizard => SearchAgent":
+                    search_text = item['text']
+
+                elif action == "Wizard => Apprentice":
+                    if len(turn_list) == 0:
+                        turn = item['text']
+                        turn_list.append(turn)
+                        continue
+
+                    # get the relevant content
+                    contents = item["context"]["contents"]
+                    selects = item["context"]["selected_contents"]
+                    flag = selects[0][0]
+                    selects = selects[1:]
+                    assert len(selects) == len(contents)
+
+                    # get the topic
+                    if flag:
+                        # no knowledge sentence is used
+                        topic = "no_topic"
+                        sent_list = ["no_passages_used"]
+                    else:
+                        # assert search_text != ""
+                        topic = search_text
+                        sent_list = []
+                        for content, select in zip(contents, selects):
+                            content = content['content']
+                            assert len(content) == len(select)
+                            for c, s in zip(content, select):
+                                if s:
+                                    sent_list.append(c)
+                        if len(sent_list) == 0:
+                            topic = "no_topic"
+                            sent_list = ["no_passages_used"]
+
+                    # get dialogue context, knowledge, and response
+                    dialog_context = " [SEP] ".join(turn_list)
+                    knwl_sent = sent_list[0]
+                    response = item['text']
+
+                    # processing
+                    topic = topic.replace("\n", "").replace("\r", \
+                        "").replace("\t", "")
+                    dialog_context = dialog_context.replace("\n", "").replace("\r", \
+                        "").replace("\t", "")
+                    knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
+                        "").replace("\t", "")
+                    response = response.replace("\n", "").replace("\r", \
+                        "").replace("\t", "")
+
+                    if topic != "no_topic":
+                        # write to the output files
+                        fproc.write(topic + "\t" + dialog_context + "\t" + \
+                            knwl_sent + "\t" + response + "\n")
+                        if fknwl:
+                            fknwl.write(knwl_sent + "\n")
+                        if fresp:
+                            # tokenize for evaluation
+                            response = " ".join(word_tokenize(response))
+                            fresp.write(response + "\n")
+                    turn_list.append(response)
+
+                elif action == "Apprentice => Wizard":
+                    turn = item['text']
+                    turn_list.append(turn)
+
+                else:
+                    assert action == "SearchAgent => Wizard"
+
+    fproc.close()
+    if fknwl:
+        fknwl.close()
+    if fresp:
+        fresp.close()


 def get_database(test_datapath, train_datapath, data_type):
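Both processing functions now emit the same four-field TSV format that their docstrings describe, plus the optional knowledge/response reference files. A minimal sketch of one processed line, with invented field values purely for illustration:

# illustrative only: one row as written to processed_file by either function
# format: topic \t dialogue context ([SEP]-joined turns) \t golden knowledge \t golden response
example_row = "\t".join([
    "Blue",                                        # topic (hypothetical)
    "Blue is my favorite color. [SEP] Mine too!",  # dialogue context
    "Blue is one of the three primary colours.",   # golden knowledge
    "Did you know blue is a primary colour?",      # golden response
])

When given, knwl_ref_file receives one knowledge sentence per line and resp_ref_file one word_tokenize-d response per line, matching the "tokenize for evaluation" comment in the code above.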
@@ -465,7 +503,6 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):
             prompt_example_list.append(prompt_example)

     # shuffle the prompt examples
-    print("length: %d" % len(prompt_example_list))
     np.random.shuffle(prompt_example_list)

     print("> writing to %s" % output_path)
@@ -476,17 +513,17 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):
         f.write(example + "\n")


-def prepare_input_for_response_generation(test_file, knowledge_file, output_file):
+def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
     """Preparing inputs for the response generation"""

-    print("> Reading knowledge file from %s" % knowledge_file)
+    print("> Reading knowledge file from %s" % knwl_gen_file)
     # get the knowledge list
-    with open(knowledge_file, "r") as f:
+    with open(knwl_gen_file, "r") as f:
         knowledge_list = f.readlines()

     print("> Processing ...")
     with open(test_file, "r") as fr:
-        with open(output_file, "w") as fw:
+        with open(processed_file, "w") as fw:
             for line_num, line in enumerate(tqdm(fr)):
                 line = line.strip()
                 splits = line.split("\t")
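The loop body is truncated by the diff view here; given the docstring and the line-by-line read of knwl_gen_file above, the function presumably substitutes the generated knowledge at index line_num for the gold knowledge field of the corresponding test line. A hedged sketch of that pairing, assuming the four-field format produced by the processing functions:

# assumption: test_file line line_num aligns with knowledge_list[line_num]
topic, context, _gold_knwl, response = splits
generated_knwl = knowledge_list[line_num].strip()
fw.write(topic + "\t" + context + "\t" + generated_knwl + "\t" + response + "\n")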
@@ -508,20 +545,20 @@ if __name__ == "__main__":
     args = get_args()

     if args.func == "process_wow_dataset":
-        process_wow_dataset(args.input_file, args.output_file)
+        process_wow_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)

     elif args.func == "process_woi_dataset":
-        process_woi_dataset(args.input_file, args.output_file)
+        process_woi_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)

     elif args.func == "get_knwl_gen_prompts":
         prompt_selection_for_knowledge_generation(
             args.test_file, args.train_file, args.model_file,
-            args.output_file, args.data_type)
+            args.processed_file, args.data_type)

     elif args.func == "get_resp_gen_prompts":
         prompt_selection_for_response_generation(
-            args.train_file, args.output_file, args.seed)
+            args.train_file, args.processed_file, args.seed)

     elif args.func == "prepare_input":
         prepare_input_for_response_generation(
-            args.test_file, args.knowledge_file, args.output_file)
+            args.test_file, args.knwl_gen_file, args.processed_file)
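With the renamed flags, one plausible invocation of the updated script for WoW preprocessing (all file paths here are hypothetical):

python preprocessing.py --func process_wow_dataset \
    --raw_file wizard_of_wikipedia/test_random_split.json \
    --processed_file wow_test_processed.txt \
    --knwl_ref_file wow_test_knwl_ref.txt \
    --resp_ref_file wow_test_resp_ref.txt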