Commit 1dacc318 authored by zihanl

delete some args and update comments

parent dfd43f49
# Multi-Stage Prompting for Knowledgeable Dialogue Generation
We present the steps to run our multi-stage dialogue prompting (MSDP), as well as the finetuning-based baselines (i.e., the finetuning-based knowledge generation (FKG) and the finetuning-based conversational model (FCM)).
## Multi-Stage Dialogue Prompting (MSDP)
......@@ -10,18 +10,18 @@ We present the steps to run our multi-stage dialogue prompting (MSDP), as well a
2. Data Processing: We provide the script ```tasks/knwl_dialo/scripts/data_processing.sh``` to process the data.
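For reference, each line of the processed data is tab-separated, and the dialogue turns are joined by ``` [SEP] ```. The following minimal sketch (the helper name is ours, not part of the repository) shows how such a line can be parsed:

```python
# Sketch: parse one processed line of the form
# topic \t dialogue context \t golden knowledge \t golden response
def parse_processed_line(line):
    topic, context, knowledge, response = line.rstrip("\n").split("\t")
    turns = context.split(" [SEP] ")  # individual dialogue turns
    return topic, turns, knowledge, response
```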
### Knowledge Generation
1. The script ```tasks/knwl_dialo/scripts/prompt_knwl_gen.sh``` provides an example of how to perform the first-stage prompting for knowledge generation.
2. The F1/KF1 scores can be evaluated with ```tasks/knwl_dialo/scripts/eval_generation.sh```. The other automatic metrics (i.e., BLEU, METEOR, and ROUGE-L) follow [nlg-eval](https://github.com/Maluuba/nlg-eval).
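For intuition, F1 (and similarly KF1) is a token-overlap F1 score between a generated sequence and a reference. The sketch below is a simplified illustration, not the repository's ```F1Metric``` implementation:

```python
from collections import Counter

def unigram_f1(pred_tokens, gold_tokens):
    """Simplified unigram F1 between a prediction and a reference."""
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
```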
### Response Generation
1. The script ```tasks/knwl_dialo/scripts/prep_respgen.sh``` prepares the input file for the response generation (based on the previously generated knowledge file).
2. The script ```tasks/knwl_dialo/scripts/prompt_resp_gen.sh``` provides an example of how to perform the second-stage prompting for response generation.
3. The automatic evaluations are the same as those described above for knowledge generation.
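The second-stage prompt examples produced by ```tasks/knwl_dialo/preprocessing.py``` follow the pattern ```Topic: ... User says: ... We know that: ... System replies: ...```. A minimal sketch of assembling a test query (the helper name is ours, and the exact query layout is an assumption based on that pattern):

```python
def build_response_query(topic, last_turn, knowledge):
    # Assumed query layout, mirroring the prompt examples built in preprocessing.py;
    # the model is expected to continue the text after "System replies:".
    return ("Topic: " + topic + ". "
            "User says: " + last_turn + " "
            "We know that: " + knowledge + " "
            "System replies:")
```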
## Finetuning-based Baselines
### FKG
The script ```tasks/knwl_dialo/scripts/finetune_knwl_gen.sh``` provides an example of how to train a finetuning-based knowledge generation model (FKG).
### FCM
The script ```tasks/knwl_dialo/scripts/finetune_resp_gen.sh``` provides an example of how to train a finetuning-based conversational model (FCM).
......
"""Model evaluation"""
from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
......@@ -17,27 +19,28 @@ from tasks.knwl_dialo.metrics import F1Metric
from tqdm import tqdm
def test_dataset_provider():
"""Build the test dataset for dialog/control module"""
"""Build the test dataset"""
args = get_args()
print_rank_0('> building the test dataset for %s module ...' \
% args.module)
if args.prompt_type != "":
print_rank_0('> evaluating ppl for prompting')
test_ds = build_test_dataset_for_prompting(
test_data_path=args.test_data_path,
prompt_file=args.prompt_file,
module=args.module,
max_seq_len=args.seq_length,
num_prompt_examples=args.num_prompt_examples,
three_turns=args.three_turns,
dynamic_prompt=args.dynamic_prompt)
else:
print_rank_0('> evaluating ppl for finetuning')
test_ds = build_test_dataset(
test_data_path=args.test_data_path,
module=args.module,
max_seq_len=args.seq_length,
last_turn=args.last_turn,
no_control_code=args.no_control_code,
add_separator=args.add_separator,
......@@ -45,7 +48,7 @@ def test_dataset_provider():
remove_ctrl_sent=args.remove_ctrl_sent)
print_rank_0("> finished creating the test dataset for %s module ..." \
% args.module)
print_rank_0('> test set size: %d' % len(test_ds))
args.eval_iters = len(test_ds) // args.global_batch_size
......@@ -68,6 +71,7 @@ def _build_test_iterator(test_dataset, task_collate_fn=None):
def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
"""Evaluating perplexity"""
args = get_args()
timers = get_timers()
......@@ -110,6 +114,7 @@ def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
def evaluate_f1(guess_file, answer_file):
"""Evaluating F1 Score"""
guess_list = []
print_rank_0('reading %s' % guess_file)
......
"""Dialogue Finetuning"""
"""Finetuning a pretrained language model for knowledge/response generation"""
import torch
from functools import partial
......@@ -42,7 +42,7 @@ def train_valid_datasets_provider():
train_data_path=args.train_data_path,
valid_data_path=args.test_data_path,
module=args.module,
max_seq_len=args.seq_length,
seed=args.seed)
print_rank_0("> finished creating datasets for %s module ..." % args.module)
......@@ -135,30 +135,30 @@ def generate_samples_input_from_file(model):
context_count = 0
model.eval()
# start the generation process
with torch.no_grad():
while True:
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
input_pos += 1
raw_text_len = len(raw_text)
context_tokens = tokenizer.tokenize(raw_text)
else:
context_tokens = tokenizer.tokenize("EMPTY TEXT")
if input_pos % 100 == 0:
print_rank_0("input_pos: %d" % input_pos)
# get the generation outputs
token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream):
pass
# write the generation to the output file
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
......@@ -194,6 +194,7 @@ def run_generation(model_provider):
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
# run generation
generate_samples_input_from_file(model)
......@@ -201,6 +202,7 @@ def main():
args = get_args()
if "FINETUNE" in args.task:
# finetune
finetune(train_valid_datasets_provider, model_provider, \
forward_step=forward_step)
else:
......
......@@ -26,7 +26,6 @@ def normalize_answer(s):
s = s.lower()
s = re_punc.sub(' ', s)
s = re_art.sub(' ', s)
# TODO: this could almost certainly be faster with a regex \s+ -> ' '
s = ' '.join(s.split())
return s
......
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
import argparse
from nltk import word_tokenize
from tqdm import tqdm
import numpy as np
import json
def get_params():
parser = argparse.ArgumentParser(description="Preprocessing")
parser.add_argument("--func", type=str, default="")
parser.add_argument("--input_file", type=str, default="")
parser.add_argument("--knowledge_file", type=str, default="")
parser.add_argument("--output_file", type=str, default="")
parser.add_argument("--func", type=str, default="",
help="choose to run which function")
parser.add_argument("--input_file", type=str, default="",
help="path of the input file")
parser.add_argument("--knowledge_file", type=str, default="",
help="path of the knowledge file")
parser.add_argument("--test_file", type=str, default="",
help="path of the test file")
parser.add_argument("--train_file", type=str, default="",
help="path of the train file")
parser.add_argument("--output_file", type=str, default="",
help="path of the output file")
parser.add_argument("--model_file", type=str, default="",
help="path of the model file")
parser.add_argument("--seed", type=int, default=123456,
help="random seed")
params = parser.parse_args()
return params
......@@ -17,14 +33,17 @@ def get_params():
def process_wow_dataset(input_file, output_file):
"""
This function processes the Wizard of Wikipedia (WoW) dataset.
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
with open(input_file, "r") as fr:
dialog_data = json.load(fr)
with open(output_file, "w") as fw:
for i, sample in enumerate(tqdm(dialog_data)):
# get all the dialog data for a single sample
dialog = sample["dialog"]
context = []
......@@ -46,6 +65,7 @@ def process_wow_dataset(input_file, output_file):
assert len(checked_sentence) <= 1
# get the ground truth knowledge
if len(checked_sentence) > 0:
checked_sentence = checked_sentence[0]
else:
......@@ -56,13 +76,15 @@ def process_wow_dataset(input_file, output_file):
else:
checked_passage = "no_passages_used"
# get the topic
if checked_passage != "no_passages_used":
topic = checked_passage
else:
topic = sample["chosen_topic"]
fw.write(topic + "\t" + " [SEP] ".join(context) + "\t" + checked_sentence + "\t" + text + "\n")
# write to the output file
fw.write(topic + "\t" + " [SEP] ".join(context) + "\t" + \
checked_sentence + "\t" + text + "\n")
context.append(text)
else:
......@@ -71,6 +93,12 @@ def process_wow_dataset(input_file, output_file):
def process_woi_dataset(input_file, output_file):
"""
This function processes the Wizard of Internet (WoI) dataset.
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
with open(output_file, "w") as fw:
with open(input_file, "r") as fr:
for i, line in tqdm(enumerate(fr)):
......@@ -93,19 +121,19 @@ def process_woi_dataset(input_file, output_file):
search_text = item['text']
elif action == "Wizard => Apprentice":
if len(turn_list) == 0:
turn = item['text']
turn_list.append(turn)
continue
# get the relevant content
contents = item["context"]["contents"]
selects = item["context"]["selected_contents"]
flag = selects[0][0]
selects = selects[1:]
assert len(selects) == len(contents)
# get the topic
if flag:
# no knowledge sentence is used
topic = "no_topic"
......@@ -121,33 +149,29 @@ def process_woi_dataset(input_file, output_file):
for c, s in zip(content, select):
if s:
sent_list.append(c)
if len(sent_list) == 0:
topic = "no_topic"
sent_list = ["no_passages_used"]
# get dialogue context, knowledge, and response
dialog_context = " [SEP] ".join(turn_list)
knwl_sent = sent_list[0]
response = item['text']
topic = topic.replace("\n", "")
topic = topic.replace("\r", "")
topic = topic.replace("\t", "")
dialog_context = dialog_context.replace("\n", "")
dialog_context = dialog_context.replace("\r", "")
dialog_context = dialog_context.replace("\t", "")
knwl_sent = knwl_sent.replace("\n", "")
knwl_sent = knwl_sent.replace("\r", "")
knwl_sent = knwl_sent.replace("\t", "")
response = response.replace("\n", "")
response = response.replace("\r", "")
response = response.replace("\t", "")
# processing
topic = topic.replace("\n", "").replace("\r", \
"").replace("\t", "")
dialog_context = dialog_context.replace("\n", "").replace("\r", \
"").replace("\t", "")
knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
"").replace("\t", "")
response = response.replace("\n", "").replace("\r", \
"").replace("\t", "")
# write to the output file
if topic != "no_topic":
fw.write(topic + "\t" + dialog_context + "\t" + knwl_sent + "\t" + response + "\n")
fw.write(topic + "\t" + dialog_context + "\t" + \
knwl_sent + "\t" + response + "\n")
turn_list.append(response)
......@@ -159,6 +183,296 @@ def process_woi_dataset(input_file, output_file):
assert action == "SearchAgent => Wizard"
def get_database(test_datapath, train_datapath):
"""Get the database sorted by topics"""
# get test data topic list
print("> reading test data from %s" % test_datapath)
test_topics = {}
with open(test_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
test_topics[topic] = True
print("> reading data from %s" % train_datapath)
train_data_by_topic = {}
dialog_data_by_topic = {}
dialog_examples = []
with open(train_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
knowledge = splits[2]
response = splits[3]
if knowledge == "no_passages_used":
continue
# get the instance
last_turn = turns[-1]
instance = "( " + last_turn + " ) " + topic + " => " + knowledge
# construct dialog example
dialog_example = ""
dialog_example += "( " + topic + " )"
for turn in turns:
dialog_example += " "
dialog_example += turn
# check whether the topic overlaps with the test set topics
if topic in test_topics:
if topic not in train_data_by_topic:
train_data_by_topic[topic] = [instance]
else:
train_data_by_topic[topic].append(instance)
if topic not in dialog_data_by_topic:
dialog_data_by_topic[topic] = [dialog_example]
else:
dialog_data_by_topic[topic].append(dialog_example)
# append all the data into dialogue examples list
dialog_examples.append((topic, dialog_example, instance))
return train_data_by_topic, dialog_data_by_topic, dialog_examples
emb_dict = {}
def select_prompts_based_on_similarity(
query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
"""Select samples based on the similarity"""
with torch.no_grad():
# get the query embeddings
query_ids = tokenizer.encode(query)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate embeddings for the samples in the database
if topic in emb_dict:
example_embeddings = emb_dict[topic]
example_embeddings = example_embeddings.cuda()
else:
for idx, example in enumerate(dialog_list):
example_ids = tokenizer.encode(example)
example_ids = torch.LongTensor([example_ids]).cuda()
example_emb = encoder(input_ids=example_ids).pooler_output
if idx == 0:
example_embeddings = example_emb
else:
example_embeddings = torch.cat(
(example_embeddings, example_emb), dim=0)
emb_dict[topic] = example_embeddings.cpu()
# compare the similarity and select the topk samples
similarity_list = example_embeddings.matmul(query_emb)
_, indices = torch.topk(similarity_list, k=topk)
indices = indices.tolist()
indices = indices[::-1] # reverse the order so the most similar example comes last
selected_prompts = []
for index in indices:
# index = index.item()
selected_prompts.append(prompt_list[index])
return selected_prompts
def prompt_selection_for_knowledge_generation(
test_datapath, train_datapath, model_path, output_prompt_path):
"""Selecting prompts for the knowledge generation"""
print("> Selecting prompts for the knowledge generation")
train_data_by_topic, dialog_data_by_topic, dialog_examples = \
get_database(test_datapath, train_datapath)
from transformers import DPRQuestionEncoderTokenizer
print("> loading tokenizer and encoder")
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
'facebook/dpr-question_encoder-single-nq-base')
encoder = torch.load(model_path).cuda()
print("> getting dialog embeddings")
with torch.no_grad():
for idx, example in tqdm(enumerate(dialog_examples)):
dialog = example[1]
dialog_ids = tokenizer.encode(dialog)
dialog_ids = torch.LongTensor([dialog_ids]).cuda()
dialog_emb = encoder(input_ids=dialog_ids).pooler_output
if idx == 0:
dialog_embeddings = dialog_emb
else:
dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)
print("> reading test data from %s" % test_datapath)
count_out_of_list = 0
prompt_list_for_each_sample = []
with open(test_datapath, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
if topic not in train_data_by_topic:
count_out_of_list += 1
# calculate similarity
# get the query embedding
query_sent = ""
query_sent += "( " + topic + " )"
for turn in turns:
query_sent += " "
query_sent += turn
query_ids = tokenizer.encode(query_sent)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate the similarity
similarity_list = dialog_embeddings.matmul(query_emb)
_, indices = torch.sort(similarity_list)
indices = indices.tolist()
selected_topics = {}
selected_prompts = []
num_prompt = 0
for index in indices:
example = dialog_examples[index]
topic_temp = example[0]
if topic_temp not in selected_topics:
selected_topics[topic_temp] = True
selected_prompts.append(example[2])
num_prompt += 1
if num_prompt == 10:
break
# get the selected samples
example_list = selected_prompts[::-1]
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
else:
num_data_sample = min(len(train_data_by_topic[topic]), 10)
total_example_list = train_data_by_topic[topic]
# construct the query sentence
query_sent = ""
query_sent += "( " + topic + " )"
for turn in turns:
query_sent += " "
query_sent += turn
dialog_list = dialog_data_by_topic[topic]
assert len(dialog_list) == num_data_sample
# calculate the similarity
selected_examples = select_prompts_based_on_similarity(
query_sent, dialog_list, total_example_list,
topic, tokenizer, encoder, topk=num_data_sample)
example_list = selected_examples
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
print("writing to %s" % output_prompt_path)
with open(output_prompt_path, "w") as f:
for instance in tqdm(prompt_list_for_each_sample):
json.dump(instance, f)
f.write("\n")
def prompt_selection_for_response_generation(input_path, output_path, seed):
"""Selecting prompts for the response generation"""
print("> Selecting prompts for the response generation")
print("> set random seed")
np.random.seed(seed)
prompt_example_list = []
print("> reading data from %s" % input_path)
with open(input_path, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
# get the topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
knowledge = splits[2]
response = splits[3]
turns = dialog_context.split(" [SEP] ")[-3:]
if knowledge == "no_passages_used":
continue
# calculate the token overlap ratio between the knowledge and the response
knowledge_sent_token_list = word_tokenize(knowledge)
knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
response_token_list = response.split()
response_len = len(response_token_list)
num_overlap_token = 0
for token in response_token_list:
if token in knowledge_sent_token_dict:
num_overlap_token += 1
# keep examples whose response has 60%-90% token overlap with the knowledge
if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
continue
prompt_example = ""
# build the prompt example: topic, last user turn, knowledge, and response
prompt_example += "Topic: " + topic + ". "
prompt_example += "User says: " + turns[-1] + " "
prompt_example += "We know that: " + knowledge + " "
prompt_example += "System replies: " + response
prompt_example_list.append(prompt_example)
print("> shuffle the prompt examples (total %d)" % len(prompt_example_list))
np.random.shuffle(prompt_example_list)
print("> Prompt example:")
print(prompt_example_list[0])
print("> writing to %s" % output_path)
with open(output_path, "w") as f:
# f.write("Generate the System's response based on the knowledge sentence:\n")
for i in tqdm(range(20)):
example = prompt_example_list[i]
f.write(example + "\n")
def prepare_input_for_response_generation(test_file, knowledge_file, output_file):
"""Preparing inputs for the response generation"""
# get the knowledge list
with open(knowledge_file, "r") as f:
knowledge_list = f.readlines()
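# the i-th line of the generated knowledge file corresponds to the i-th test example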
with open(test_file, "r") as fr:
with open(output_file, "w") as fw:
for line_num, line in enumerate(tqdm(fr)):
line = line.strip()
splits = line.split("\t")
# prepare topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
response = splits[3]
knowledge = knowledge_list[line_num]
knowledge = knowledge.strip()
if "<|endoftext|>" in knowledge:
knowledge = knowledge.replace("<|endoftext|>", "")
# write to the output file
fw.write(topic + "\t" + dialog_context + "\t" \
+ knowledge + "\t" + response + "\n")
if __name__ == "__main__":
......@@ -168,3 +482,13 @@ if __name__ == "__main__":
elif params.func == "process_woi_dataset":
process_woi_dataset(params.input_file, params.output_file)
elif params.func == "get_prompts":
prompt_selection_for_knowledge_generation(
params.test_file, params.train_file, params.model_file, params.output_file)
prompt_selection_for_response_generation(
params.train_file, params.output_file, params.seed)
elif params.func == "prepare_input":
prepare_input_for_response_generation(
params.test_file, params.knowledge_file, params.output_file)
"""Prompting the pretrained language model to generate knowledge/response"""
import json
import torch
from nltk import word_tokenize
......@@ -27,7 +29,9 @@ def model_provider(pre_process=True, post_process=True):
def generate_samples_by_prompting_input_from_file(model):
"""Prompt a pretrained language model to generate knowledge/response"""
# get args and tokenizer
args = get_args()
tokenizer = get_tokenizer()
......@@ -57,17 +61,17 @@ def generate_samples_by_prompting_input_from_file(model):
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
# get the prompt examples based on the key
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt_examples_dict[key] = prompt
else:
# prompts are fixed for all test samples
with open(args.prompt_file, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:args.num_prompt_examples]
......@@ -77,13 +81,14 @@ def generate_samples_by_prompting_input_from_file(model):
instance = instance.strip()
prompt += instance + " \n"
# only two prompt types (i.e., knowledge and response) are allowed
assert args.prompt_type in ["knowledge", "response"]
context_count = 0
model.eval()
# perform prompting
with torch.no_grad():
while True:
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
input_str = all_raw_text[input_pos]
......@@ -92,16 +97,17 @@ def generate_samples_by_prompting_input_from_file(model):
control_codes = splits[0].split(" [CTRL] ")
topic = control_codes[0]
# first add the prompt into the inputs
if args.dynamic_prompt:
turns = splits[1].split(" [SEP] ")
last_turn = turns[-1]
key = topic + " " + last_turn
raw_text = prompt_examples_dict[key]
else:
raw_text = prompt
if args.prompt_type == "knowledge":
# construct inputs for knowledge generation
turns = splits[1].split(" [SEP] ")
context = turns[-1]
if " -> " in raw_text and " => " not in raw_text:
......@@ -110,11 +116,11 @@ def generate_samples_by_prompting_input_from_file(model):
raw_text += "( " + context + " ) " + topic + " =>"
else:
# args.prompt_type == "response":
# construct inputs for response generation
# args.prompt_type == "response"
turns = splits[1].split(" [SEP] ")
knowledge = splits[2]
knowledge = " ".join(word_tokenize(knowledge))
last_turn = turns[-1]
knowledge = knowledge.strip()
last_turn = last_turn.strip()
......@@ -137,9 +143,9 @@ def generate_samples_by_prompting_input_from_file(model):
for _, decode_tokens in enumerate(token_stream):
pass
# write the generated output to the output file
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
......@@ -147,13 +153,11 @@ def generate_samples_by_prompting_input_from_file(model):
generated_output = trim_decode_tokens.split("\n")[0]
generated_output = generated_output.strip()
fname_out.write(generated_output)
fname_out.write("\n")
raw_text = None
context_count += 1
if input_pos == input_count:
return
......@@ -174,4 +178,5 @@ def main():
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
# perform the prompting
generate_samples_by_prompting_input_from_file(model)
#!/bin/bash
# Data preparation for our framework: preprocessing the WoW and WoI datasets
# The datasets can be downloaded through the following links:
# WoW: https://parl.ai/projects/wizard_of_wikipedia/
# WoI: https://parl.ai/projects/sea/
DIR=`pwd`
mkdir -p $DIR/tasks/knwl_dialo/data
......@@ -9,6 +14,9 @@ python ${DIR}/tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --inp
# We provide the following script to process the raw data from Wizard of Internet
python ${DIR}/tasks/knwl_dialo/preprocessing.py --func process_woi_dataset --input_file <PATH_OF_THE_INPUT_DATA> --output_file <PATH_OF_THE_OUTPUT_DATA>
# Obtain the knowledge generation prompts and response generation prompts
python ${DIR}/tasks/knwl_dialo/preprocessing.py --func get_prompts --test_file <PATH_OF_THE_PROCESSED_TEST_DATA> --train_file <PATH_OF_THE_PROCESSED_TRAIN_DATA> --model_file <PATH_OF_THE_DPR_MODEL> --output_file <PATH_OF_THE_OUTPUT_FILE>
# Alternatively, we recommend directly downloading the already-processed file through:
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1vP0eGxhkbWfeJ2dUUOEAflbOZq-Jlde_' -O data.gz
#!/bin/bash
# This script is used to evaluate the F1 or KF1 scores.
WORLD_SIZE=1
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
......
#!/bin/bash
# Finetune a pretrained language model to generate the context-relevant knowledge
# The input is the dialogue context, and the output is the relevant knowledge
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
......@@ -16,8 +20,6 @@ TRAIN_PATH=<Specify path for the training dataset>
TEST_PATH=<Specify path for the test dataset>
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......@@ -31,17 +33,13 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--lr 1.5e-5 \
--min-lr 1.0e-5 \
--lr-decay-style cosine \
--log-interval 100 \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--save-interval 10000 \
--save ${OUTPUT_MODEL_PATH} \
--pretrained-checkpoint ${CHECKPOINT_PATH} \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.02 \
--log-params-norm \
--log-num-zeros-in-grad \
--fp16 \
......@@ -51,7 +49,6 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--task KNWL-DIALO-FINETUNE \
--module knowledge \
--spec-toks [SEP],[CTRL],[PAD] \
--train-data ${TRAIN_PATH} \
--test-data ${TEST_PATH} \
--tokenizer-type GPT2BPETokenizer
#!/bin/bash
# Finetune a pretrained language model to generate the corresponding response
# The input is the dialogue context and knowledge, and the output is the response
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
......@@ -16,8 +20,6 @@ TRAIN_PATH=<Specify path for the training dataset>
TEST_PATH=<Specify path for the test dataset>
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
......@@ -31,17 +33,13 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--lr 1.0e-5 \
--min-lr 5.0e-6 \
--lr-decay-style cosine \
--log-interval 100 \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--save-interval 10000 \
--save ${OUTPUT_MODEL_PATH} \
--pretrained-checkpoint ${CHECKPOINT_PATH} \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.02 \
--log-params-norm \
--log-num-zeros-in-grad \
--fp16 \
......@@ -51,7 +49,6 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--task KNWL-DIALO-FINETUNE \
--module response \
--spec-toks [SEP],[CTRL],[PAD] \
--train-data ${TRAIN_PATH} \
--test-data ${TEST_PATH} \
--tokenizer-type GPT2BPETokenizer
#!/bin/bash
# Preparing the input file for the response generation (second-stage prompting)
DIR=`pwd`
python ${DIR}/tasks/knwl_dialo/preprocessing.py --func prepare_input --test_file <PATH_OF_THE_PROCESSED_TEST_DATA> --knowledge_file <PATH_OF_THE_GENERATED_KNOWLEDGE_DATA> --output_file <PATH_OF_THE_OUTPUT_FILE>
#!/bin/bash
# Stage-1: Prompt a pretrained language model to generate the context-relevant knowledge
# The input contains the prompts and the current dialogue context; the output is the relevant knowledge
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
......@@ -10,25 +14,24 @@ DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
CHECKPOINT_PATH=<Specify path for the language model>
INPUT_PATH=<Specify path for the input test dataset>
VOCAB_PATH=<Specify path for the vocab file>
MERGE_PATH=<Specify path for the merge file>
OUTPUT_PATH=<Specify path for the output>
PROMPT_PATH=<Specify path for the prompts>
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 1 \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--load ${CHECKPOINT_PATH} \
--fp16 \
--DDP-impl torch \
--tokenizer-type GPT2BPETokenizer \
--out-seq-length 100 \
--sample-input-file ${INPUT_PATH} \
--sample-output-file ${OUTPUT_PATH} \
--prompt-file ${PROMPT_PATH} \
......
#!/bin/bash
# Stage-2: Prompt a pretrained language model to generate the corresponding response
# The input contains the prompts, the current dialogue context, and the knowledge generated in Stage-1
# The output is the corresponding response
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
......@@ -10,25 +15,24 @@ DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
CHECKPOINT_PATH=<Specify path for the language model>
INPUT_PATH=<Specify path for the input test dataset>
VOCAB_PATH=<Specify path for the vocab file>
MERGE_PATH=<Specify path for the merge file>
OUTPUT_PATH=<Specify path for the output>
PROMPT_PATH=<Specify path for the prompts>
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 1 \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--load ${CHECKPOINT_PATH} \
--fp16 \
--DDP-impl torch \
--tokenizer-type GPT2BPETokenizer \
--out-seq-length 100 \
--sample-input-file ${INPUT_PATH} \
--sample-output-file ${OUTPUT_PATH} \
--prompt-file ${PROMPT_PATH} \
......
"""Utils (functions) for both prompting and finetuning"""
import torch
from megatron import mpu
from megatron import get_args
......@@ -11,7 +13,11 @@ from megatron.model import Float16Module
def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
"""Build attention masks and position id for left to right model."""
"""
Build attention masks and position id for left to right model.
Different from the existing get_ltor_masks_and_position_ids function,
we add padding to the input sequences to make sure their lengths are the same.
"""
micro_batch_size, seq_length = data.size()
......@@ -38,6 +44,7 @@ def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
def switch(val1, val2, boolean):
"""Return either val1 or val2 depending on boolean"""
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
......@@ -46,6 +53,7 @@ def switch(val1, val2, boolean):
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
layer_past=None, get_key_value=None,
forward_method_parallel_output=None):
"""Forward step to get the outputs"""
# Hidden size changes when not using recompute, need to tell p2p_communicate
# functions the correct size
args = get_args()
......@@ -73,24 +81,28 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
def pad_batch(batch, pad_id, args):
"""Pad the context tokens using pad_id"""
context_lengths = []
for tokens in batch:
context_length = len(tokens)
# padding
if context_length < args.seq_length:
tokens.extend([pad_id] * (args.seq_length - context_length))
# record the original context length
context_lengths.append(context_length)
return batch, context_lengths
def get_batch(context_tokens):
"""Generate batch from context tokens."""
args = get_args()
tokenizer = get_tokenizer()
# Move to GPU.
tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
# Get the attention mask and position ids for the context tokens.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
......@@ -104,6 +116,7 @@ def get_batch(context_tokens):
def sample_sequence_batch(model, context_tokens, context_lengths,
attention_mask, position_ids,
maxlen=None, type_ids=None):
"""Obtain batch-level generation outputs"""
args = get_args()
tokenizer = get_tokenizer()
......@@ -122,18 +135,18 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
counter = 0
org_context_length = context_length
# prepare batch size, context tokens, maximum length
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = args.seq_length - 1
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
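# i.e., generate at most args.out_seq_length tokens beyond the original context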
lengths = torch.ones([batch_size]).long().cuda() * maxlen
# start the generation process
while context_length <= (maxlen):
# forward and obtain the logits
output = forward_step(model, tokens,
position_ids,
attention_mask,
......@@ -143,11 +156,13 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
assert output is not None
logits = output[:, context_length - 1, :]
# generate tokens iteratively
if mpu.is_pipeline_last_stage():
prev = torch.argmax(logits, dim=-1).view(-1)
# start to add new tokens when the generated length
# exceeds the context length
started = context_lengths <= context_length
new_tokens = switch(
tokens[:, context_length].view(-1), prev, started)
tokens[:, context_length] = new_tokens
......@@ -155,6 +170,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
group = mpu.get_embedding_group()
torch.distributed.broadcast(new_tokens, src, group)
# check whether the generation is finished
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
......@@ -189,13 +205,17 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
def get_token_stream(model, context_tokens):
"""Get output tokens iteratively"""
# get args and tokenizer
args = get_args()
tokenizer = get_tokenizer()
# padding for context tokens
context_tokens, context_lengths = pad_batch(context_tokens,
tokenizer.eod, args)
# move tokens to CUDA
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
......@@ -206,9 +226,11 @@ def get_token_stream(model, context_tokens):
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
# prepare batch
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
# get generation outputs
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
context_length_tensor,
attention_mask, position_ids)
......
......@@ -85,8 +85,6 @@ def get_tasks_args(parser):
' take from each question pool')
# parameters for the knowledgeable dialogue generation
group.add_argument("--out-seq-length", type=int, default=1024,
help='Size of the output generated text.')
group.add_argument("--sample-input-file", type=str, default=None,
help='Get input from file instead of interactive mode, '
'each line is an input.')
......@@ -102,20 +100,12 @@ def get_tasks_args(parser):
help='using different prompts for different test samples')
group.add_argument('--module', type=str, default="",
help='either knowledge generation (knowledge) or response generation (response)')
group.add_argument('--train-data-path', type=str, default="",
help='datapath for training set')
group.add_argument('--test-data-path', type=str, default="",
help='datapath for test set')
group.add_argument('--guess-file', type=str, default="",
help='datapath for generated sentences')
group.add_argument('--answer-file', type=str, default="",
help='datapath for golden sentences')
group.add_argument('--max-seq-len', type=int, default=1024,
help='maximum sequence length')
group.add_argument('--spec-toks', type=str, default=None,
help='additional special tokens')
group.add_argument('--eval-prompting', action='store_true',
help='Whether to evaluate prompting')
return parser
......