Commit 8c9c145d authored by zihanl's avatar zihanl
Browse files

update preprocessing.py

parent 91a80bd1
......@@ -13,16 +13,20 @@ def get_args():
parser.add_argument("--func", type=str, default=None,
help="choose to run which function")
parser.add_argument("--input_file", type=str, default=None,
parser.add_argument("--raw_file", type=str, default=None,
help="path of the input file")
parser.add_argument("--knowledge_file", type=str, default=None,
help="path of the knowledge file")
parser.add_argument("--processed_file", type=str, default=None,
help="path of the output file")
parser.add_argument("--knwl_ref_file", type=str, default=None,
help="path of the knowledge reference file")
parser.add_argument("--resp_ref_file", type=str, default=None,
help="path of the knowledge reference file")
parser.add_argument("--knwl_gen_file", type=str, default=None,
help="path of the generated knowledge file")
parser.add_argument("--test_file", type=str, default=None,
help="path of the test file")
parser.add_argument("--train_file", type=str, default=None,
help="path of the train file")
parser.add_argument("--output_file", type=str, default=None,
help="path of the output file")
parser.add_argument("--model_file", type=str, default=None,
help="path of the model file")
parser.add_argument("--data_type", type=str, default=None,
......@@ -34,158 +38,192 @@ def get_args():
return args
def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """Process the Wizard of Wikipedia (wow) dataset.

    Expected processed format (one example per line):
        topic \t dialogue context \t golden knowledge \t golden response

    Args:
        raw_file: path of the raw wow JSON file (a list of dialogue samples).
        processed_file: path of the tab-separated output file.
        knwl_ref_file: optional path; if given, the golden knowledge sentence
            of each example is also written here, one per line.
        resp_ref_file: optional path; if given, the golden response (tokenized
            for evaluation) is also written here, one per line.
    """
    print("> Loading data from %s" % raw_file)
    with open(raw_file, "r") as fr:
        dialog_data = json.load(fr)

    print("> Processing data ...")
    fproc = open(processed_file, "w")
    # reference files are optional; only open them when a path was given
    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
    fresp = open(resp_ref_file, "w") if resp_ref_file else None

    for sample in tqdm(dialog_data):
        # get all the dialog turns for a single sample
        dialog = sample["dialog"]

        context = []
        for j, turn in enumerate(dialog):
            text = turn["text"]
            # make sure each turn ends with a sentence terminator
            if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
                text = text + "."

            if j == 0:
                # first turn only seeds the context; no example is produced
                context.append(text)
                continue

            speaker = turn["speaker"].lower()
            if "wizard" in speaker:
                checked_sentence = list(turn["checked_sentence"].values())  # knowledge
                checked_passage = list(turn["checked_passage"].values())    # topic
                assert len(checked_sentence) <= 1

                # get the ground truth knowledge sentence (if any was used)
                if len(checked_sentence) > 0:
                    checked_sentence = checked_sentence[0]
                else:
                    checked_sentence = "no_passages_used"

                if len(checked_passage) == 1:
                    checked_passage = checked_passage[0]
                else:
                    checked_passage = "no_passages_used"

                # get the topic: prefer the checked passage, otherwise fall
                # back to the sample-level chosen topic
                if checked_passage != "no_passages_used":
                    topic = checked_passage
                else:
                    topic = sample["chosen_topic"]

                knowledge = checked_sentence
                response = text

                # write to the output files
                fproc.write(topic + "\t" + " [SEP] ".join(context) + "\t" + \
                            knowledge + "\t" + response + "\n")
                if fknwl:
                    fknwl.write(knowledge + "\n")
                if fresp:
                    # tokenize the response for evaluation; use a separate
                    # variable so the raw text is what enters the context
                    tokenized_response = " ".join(word_tokenize(response))
                    fresp.write(tokenized_response + "\n")

                context.append(text)
            else:
                # non-wizard turns only extend the dialogue context
                assert "apprentice" in speaker
                context.append(text)

    fproc.close()
    if fknwl:
        fknwl.close()
    if fresp:
        fresp.close()
def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """Process the Wizard of Internet (woi) dataset.

    Expected processed format (one example per line):
        topic \t dialogue context \t golden knowledge \t golden response

    Args:
        raw_file: path of the raw woi file (one JSON object per line, each
            mapping a single id to a dict with a 'dialog_history' list).
        processed_file: path of the tab-separated output file.
        knwl_ref_file: optional path; if given, the golden knowledge sentence
            of each example is also written here, one per line.
        resp_ref_file: optional path; if given, the golden response (tokenized
            for evaluation) is also written here, one per line.
    """
    print("> Processing %s" % raw_file)
    fproc = open(processed_file, "w")
    # reference files are optional; only open them when a path was given
    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
    fresp = open(resp_ref_file, "w") if resp_ref_file else None

    with open(raw_file, "r") as fr:
        for i, line in tqdm(enumerate(fr)):
            line = line.strip()
            item_dict = json.loads(line)
            # each line holds exactly one {id: payload} pair
            item_dict = item_dict.values()
            assert len(item_dict) == 1
            item_dict = list(item_dict)[0]

            dialog_data = item_dict['dialog_history']
            turn_list = []
            search_text = ""
            for turn_idx in range(len(dialog_data)):
                item = dialog_data[turn_idx]
                action = item['action']

                if action == "Wizard => SearchAgent":
                    # remember the latest search query; it becomes the topic
                    search_text = item['text']
                elif action == "Wizard => Apprentice":
                    if len(turn_list) == 0:
                        # wizard speaks first: just seed the context
                        turn = item['text']
                        turn_list.append(turn)
                        continue

                    # get the relevant content
                    contents = item["context"]["contents"]
                    selects = item["context"]["selected_contents"]
                    # the leading flag marks "no passages used"
                    flag = selects[0][0]
                    selects = selects[1:]
                    assert len(selects) == len(contents)

                    # get the topic and the selected knowledge sentences
                    if flag:
                        # no knowledge sentence is used
                        topic = "no_topic"
                        sent_list = ["no_passages_used"]
                    else:
                        # assert search_text != ""
                        topic = search_text
                        sent_list = []
                        for content, select in zip(contents, selects):
                            content = content['content']
                            assert len(content) == len(select)
                            for c, s in zip(content, select):
                                if s:
                                    sent_list.append(c)
                        if len(sent_list) == 0:
                            topic = "no_topic"
                            sent_list = ["no_passages_used"]

                    # get dialogue context, knowledge, and response
                    dialog_context = " [SEP] ".join(turn_list)
                    knwl_sent = sent_list[0]
                    response = item['text']

                    # strip characters that would break the TSV format
                    topic = topic.replace("\n", "").replace("\r", \
                        "").replace("\t", "")
                    dialog_context = dialog_context.replace("\n", "").replace("\r", \
                        "").replace("\t", "")
                    knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
                        "").replace("\t", "")
                    response = response.replace("\n", "").replace("\r", \
                        "").replace("\t", "")

                    if topic != "no_topic":
                        # write to the output files
                        fproc.write(topic + "\t" + dialog_context + "\t" + \
                                    knwl_sent + "\t" + response + "\n")
                        if fknwl:
                            fknwl.write(knwl_sent + "\n")
                        if fresp:
                            # tokenize for evaluation; keep the raw response
                            # separate so the context is not polluted with
                            # tokenized text (bug fix)
                            tokenized_response = " ".join(word_tokenize(response))
                            fresp.write(tokenized_response + "\n")

                    turn_list.append(response)
                elif action == "Apprentice => Wizard":
                    turn = item['text']
                    turn_list.append(turn)
                else:
                    # search results flowing back to the wizard carry no turn
                    assert action == "SearchAgent => Wizard"

    fproc.close()
    if fknwl:
        fknwl.close()
    if fresp:
        fresp.close()
def get_database(test_datapath, train_datapath, data_type):
......@@ -465,7 +503,6 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):
prompt_example_list.append(prompt_example)
# shuffle the prompt examples
print("length: %d" % len(prompt_example_list))
np.random.shuffle(prompt_example_list)
print("> writing to %s" % output_path)
......@@ -476,17 +513,17 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):
f.write(example + "\n")
def prepare_input_for_response_generation(test_file, knowledge_file, output_file):
def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
"""Preparing inputs for the response generation"""
print("> Reading knowledge file from %s" % knowledge_file)
print("> Reading knowledge file from %s" % knwl_gen_file)
# get the knowledge list
with open(knowledge_file, "r") as f:
with open(knwl_gen_file, "r") as f:
knowledge_list = f.readlines()
print("> Processing ...")
with open(test_file, "r") as fr:
with open(output_file, "w") as fw:
with open(processed_file, "w") as fw:
for line_num, line in enumerate(tqdm(fr)):
line = line.strip()
splits = line.split("\t")
......@@ -508,20 +545,20 @@ if __name__ == "__main__":
args = get_args()

# dispatch to the requested preprocessing step based on --func
if args.func == "process_wow_dataset":
    process_wow_dataset(args.raw_file, args.processed_file,
                        args.knwl_ref_file, args.resp_ref_file)

elif args.func == "process_woi_dataset":
    process_woi_dataset(args.raw_file, args.processed_file,
                        args.knwl_ref_file, args.resp_ref_file)

elif args.func == "get_knwl_gen_prompts":
    prompt_selection_for_knowledge_generation(
        args.test_file, args.train_file, args.model_file,
        args.processed_file, args.data_type)

elif args.func == "get_resp_gen_prompts":
    prompt_selection_for_response_generation(
        args.train_file, args.processed_file, args.seed)

elif args.func == "prepare_input":
    prepare_input_for_response_generation(
        args.test_file, args.knwl_gen_file, args.processed_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment