Commit 1dacc318 authored by zihanl

delete some args and update comments

parent dfd43f49
# Multi-Stage Prompting for Knowledgeable Dialogue Generation
We present the steps to run our multi-stage dialogue prompting (MSDP), as well as the finetuning-based baselines (i.e., finetuning-based knowledge generation (FKG) and finetuning-based conversation model (FCM)).
## Multi-Stage Dialogue Prompting (MSDP)
@@ -10,18 +10,18 @@ We present the steps to run our multi-stage dialogue prompting (MSDP), as well a
2. Data Processing: We provide the script ```tasks/knwl_dialo/scripts/data_processing.sh``` to process the data. The processed data format is sketched below.
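Each processed line is tab-separated as `topic \t dialogue context \t golden knowledge \t golden response`, with the turns inside the dialogue context joined by ` [SEP] ` (see the docstrings in ```tasks/knwl_dialo/preprocessing.py```). A minimal parsing sketch, using a purely illustrative example line:

```python
# Illustrative example line (not taken from the released data).
line = "Blue\tdo you like the color blue? [SEP] yes, it is my favorite\tBlue is one of the three primary colours.\tMe too, blue is a primary colour!"
topic, context, knowledge, response = line.strip().split("\t")
turns = context.split(" [SEP] ")  # individual dialogue turns
```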
### Knowledge Generation
1. The script ```tasks/knwl_dialo/scripts/prompt_knwl_gen.sh``` provides an example of how to perform the first-stage prompting for the knowledge generation.
2. The F1/KF1 score can be evaluated through ```tasks/knwl_dialo/scripts/eval_generation.sh``` (a minimal sketch of the metric is given below). Other automatic metrics (i.e., BLEU, METEOR, and ROUGE-L) follow [nlg-eval](https://github.com/Maluuba/nlg-eval).
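For reference, F1 here is the unigram overlap between the generated text and the reference, and KF1 computes the same overlap against the ground-truth knowledge sentence. A minimal sketch of the metric (the repository's `F1Metric` may differ in normalization details):

```python
from collections import Counter

def unigram_f1(guess, answer):
    # token-level precision/recall/F1 between a generated and a reference string
    guess_tokens, answer_tokens = guess.split(), answer.split()
    common = Counter(guess_tokens) & Counter(answer_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(guess_tokens)
    recall = num_same / len(answer_tokens)
    return 2 * precision * recall / (precision + recall)
```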
### Response Generation
1. The script ```tasks/knwl_dialo/scripts/prep_respgen.sh``` helps to prepare the input file for the response generation (based on the previously generated knowledge file).
2. The script ```tasks/knwl_dialo/scripts/prompt_resp_gen.sh``` provides an example of how to perform the second-stage prompting for the response generation.
3. The automatic evaluations are the same as mentioned above for the knowledge generation.
## Finetuning-based Baselines
### FKG
The script ```tasks/knwl_dialo/scripts/finetune_knwl_gen.sh``` provides an example of how to train a finetuning-based knowledge generation model (FKG).
### FCM
The script ```tasks/knwl_dialo/scripts/finetune_resp_gen.sh``` provides an example of how to train a finetuning-based conversational model (FCM).
......
"""Model evaluation"""
from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
@@ -17,27 +19,28 @@ from tasks.knwl_dialo.metrics import F1Metric
from tqdm import tqdm
def test_dataset_provider():
"""Build the test dataset"""
args = get_args()
print_rank_0('> building the test dataset for %s module ...' \
% args.module)
if args.prompt_type != "":
print_rank_0('> evaluating ppl for prompting')
test_ds = build_test_dataset_for_prompting(
test_data_path=args.test_data_path,
prompt_file=args.prompt_file,
module=args.module,
max_seq_len=args.seq_length,
num_prompt_examples=args.num_prompt_examples,
three_turns=args.three_turns,
dynamic_prompt=args.dynamic_prompt)
else:
print_rank_0('> evaluating ppl for finetuning')
test_ds = build_test_dataset(
test_data_path=args.test_data_path,
module=args.module,
max_seq_len=args.seq_length,
last_turn=args.last_turn,
no_control_code=args.no_control_code,
add_separator=args.add_separator,
@@ -45,7 +48,7 @@ def test_dataset_provider():
remove_ctrl_sent=args.remove_ctrl_sent)
print_rank_0("> finished creating the test dataset for %s module ..." \
% args.module)
print_rank_0('> test set size: %d' % len(test_ds))
args.eval_iters = len(test_ds) // args.global_batch_size
@@ -68,6 +71,7 @@ def _build_test_iterator(test_dataset, task_collate_fn=None):
def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
"""Evaluating perplexity"""
args = get_args()
timers = get_timers()
@@ -110,6 +114,7 @@ def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
def evaluate_f1(guess_file, answer_file):
"""Evaluating F1 Score"""
guess_list = []
print_rank_0('reading %s' % guess_file)
......
"""Dialogue Finetuning""" """Finetuning a pretrained language model for knowledge/response generation"""
import torch import torch
from functools import partial from functools import partial
...@@ -42,7 +42,7 @@ def train_valid_datasets_provider(): ...@@ -42,7 +42,7 @@ def train_valid_datasets_provider():
train_data_path=args.train_data_path, train_data_path=args.train_data_path,
valid_data_path=args.test_data_path, valid_data_path=args.test_data_path,
module=args.module, module=args.module,
max_seq_len=args.max_seq_len, max_seq_len=args.seq_length,
seed=args.seed) seed=args.seed)
print_rank_0("> finished creating datasets for %s module ..." % args.module) print_rank_0("> finished creating datasets for %s module ..." % args.module)
...@@ -135,30 +135,30 @@ def generate_samples_input_from_file(model): ...@@ -135,30 +135,30 @@ def generate_samples_input_from_file(model):
context_count = 0 context_count = 0
model.eval() model.eval()
# start the generation process
with torch.no_grad(): with torch.no_grad():
while True: while True:
raw_text_len = 0 raw_text_len = 0
if mpu.is_pipeline_first_stage() \ if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0: and mpu.get_tensor_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos] raw_text = all_raw_text[input_pos]
input_pos += 1 input_pos += 1
raw_text_len = len(raw_text) raw_text_len = len(raw_text)
context_tokens = tokenizer.tokenize(raw_text) context_tokens = tokenizer.tokenize(raw_text)
else: else:
context_tokens = tokenizer.tokenize("EMPTY TEXT") context_tokens = tokenizer.tokenize("EMPTY TEXT")
if input_pos % 100 == 0: if input_pos % 100 == 0:
print_rank_0("input_pos: %d" % input_pos) print_rank_0("input_pos: %d" % input_pos)
# get the generation outputs
token_stream = get_token_stream(model, [context_tokens]) token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream): for _, decode_tokens in enumerate(token_stream):
pass pass
# write the generation to the output file
if mpu.get_tensor_model_parallel_rank() == 0: if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist() decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize( trim_decode_tokens = tokenizer.detokenize(
...@@ -194,6 +194,7 @@ def run_generation(model_provider): ...@@ -194,6 +194,7 @@ def run_generation(model_provider):
assert len(model) == 1, "Above condition should have caught this" assert len(model) == 1, "Above condition should have caught this"
model = model[0] model = model[0]
# run generation
generate_samples_input_from_file(model) generate_samples_input_from_file(model)
...@@ -201,6 +202,7 @@ def main(): ...@@ -201,6 +202,7 @@ def main():
args = get_args() args = get_args()
if "FINETUNE" in args.task: if "FINETUNE" in args.task:
# finetune
finetune(train_valid_datasets_provider, model_provider, \ finetune(train_valid_datasets_provider, model_provider, \
forward_step=forward_step) forward_step=forward_step)
else: else:
......
@@ -26,7 +26,6 @@ def normalize_answer(s):
s = s.lower()
s = re_punc.sub(' ', s)
s = re_art.sub(' ', s)
s = ' '.join(s.split())
return s
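# For reference, re_punc and re_art above are assumed to be precompiled regexes
# that strip punctuation and the articles a/an/the (ParlAI-style normalization), e.g.:
#   re_art = re.compile(r'\b(a|an|the)\b')
#   re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')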
......
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
import argparse
from nltk import word_tokenize
from tqdm import tqdm
import numpy as np
import json
def get_params():
parser = argparse.ArgumentParser(description="Preprocessing")
parser.add_argument("--func", type=str, default="") parser.add_argument("--func", type=str, default="",
parser.add_argument("--input_file", type=str, default="") help="choose to run which function")
parser.add_argument("--knowledge_file", type=str, default="") parser.add_argument("--input_file", type=str, default="",
parser.add_argument("--output_file", type=str, default="") help="path of the input file")
parser.add_argument("--knowledge_file", type=str, default="",
help="path of the knowledge file")
parser.add_argument("--test_file", type=str, default="",
help="path of the test file")
parser.add_argument("--train_file", type=str, default="",
help="path of the train file")
parser.add_argument("--output_file", type=str, default="",
help="path of the output file")
parser.add_argument("--model_file", type=str, default="",
help="path of the model file")
parser.add_argument("--seed", type=int, default=123456,
help="random seed")
params = parser.parse_args() params = parser.parse_args()
return params return params
@@ -17,14 +33,17 @@ def get_params():
def process_wow_dataset(input_file, output_file):
"""
This function processes the Wizard of Wikipedia (WoW) dataset.
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
with open(input_file, "r") as fr:
dialog_data = json.load(fr)
with open(output_file, "w") as fw:
for i, sample in enumerate(tqdm(dialog_data)):
# get all the dialog data for a single sample
dialog = sample["dialog"]
context = []
@@ -46,6 +65,7 @@ def process_wow_dataset(input_file, output_file):
assert len(checked_sentence) <= 1
# get the ground truth knowledge
if len(checked_sentence) > 0:
checked_sentence = checked_sentence[0]
else:
@@ -56,13 +76,15 @@ def process_wow_dataset(input_file, output_file):
else:
checked_passage = "no_passages_used"
# get the topic
if checked_passage != "no_passages_used":
topic = checked_passage
else:
topic = sample["chosen_topic"]
# write to the output file
fw.write(topic + "\t" + " [SEP] ".join(context) + "\t" + \
checked_sentence + "\t" + text + "\n")
context.append(text)
else:
@@ -71,6 +93,12 @@ def process_wow_dataset(input_file, output_file):
def process_woi_dataset(input_file, output_file):
"""
This function processes the Wizard of Internet (WoI) dataset.
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
with open(output_file, "w") as fw:
with open(input_file, "r") as fr:
for i, line in tqdm(enumerate(fr)):
@@ -93,19 +121,19 @@ def process_woi_dataset(input_file, output_file):
search_text = item['text']
elif action == "Wizard => Apprentice":
if len(turn_list) == 0:
turn = item['text']
turn_list.append(turn)
continue
# get the relevant content
contents = item["context"]["contents"]
selects = item["context"]["selected_contents"]
flag = selects[0][0]
selects = selects[1:]
assert len(selects) == len(contents)
# get the topic
if flag:
# no knowledge sentence is used
topic = "no_topic"
@@ -121,33 +149,29 @@ def process_woi_dataset(input_file, output_file):
for c, s in zip(content, select):
if s:
sent_list.append(c)
if len(sent_list) == 0:
topic = "no_topic"
sent_list = ["no_passages_used"]
# get dialogue context, knowledge, and response
dialog_context = " [SEP] ".join(turn_list)
knwl_sent = sent_list[0]
response = item['text']
# processing
topic = topic.replace("\n", "").replace("\r", \
"").replace("\t", "")
dialog_context = dialog_context.replace("\n", "").replace("\r", \
"").replace("\t", "")
knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
"").replace("\t", "")
response = response.replace("\n", "").replace("\r", \
"").replace("\t", "")
# write to the output file
if topic != "no_topic":
fw.write(topic + "\t" + dialog_context + "\t" + \
knwl_sent + "\t" + response + "\n")
turn_list.append(response)
@@ -159,6 +183,296 @@ def process_woi_dataset(input_file, output_file):
assert action == "SearchAgent => Wizard"
def get_database(test_datapath, train_datapath):
"""Get the database sorted by topics"""
# get test data topic list
print("> reading test data from %s" % test_datapath)
test_topics = {}
with open(test_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
test_topics[topic] = True
print("> reading data from %s" % train_datapath)
train_data_by_topic = {}
dialog_data_by_topic = {}
dialog_examples = []
with open(train_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
knowledge = splits[2]
response = splits[3]
if knowledge == "no_passages_used":
continue
# get the instance
last_turn = turns[-1]
instance = "( " + last_turn + " ) " + topic + " => " + knowledge
# construct dialog example
dialog_example = ""
dialog_example += "( " + topic + " )"
for turn in turns:
dialog_example += " "
dialog_example += turn
# check overlaps
if topic in test_topics:
if topic not in train_data_by_topic:
train_data_by_topic[topic] = [instance]
else:
train_data_by_topic[topic].append(instance)
if topic not in dialog_data_by_topic:
dialog_data_by_topic[topic] = [dialog_example]
else:
dialog_data_by_topic[topic].append(dialog_example)
# append all the data into dialogue examples list
dialog_examples.append((topic, dialog_example, instance))
return train_data_by_topic, dialog_data_by_topic, dialog_examples
emb_dict = {}
def select_prompts_based_on_similarity(
query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
"""Select samples based on the similarity"""
with torch.no_grad():
# get the query embeddings
query_ids = tokenizer.encode(query)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate embeddings for the samples in the database
if topic in emb_dict:
example_embeddings = emb_dict[topic]
example_embeddings = example_embeddings.cuda()
else:
for idx, example in enumerate(dialog_list):
example_ids = tokenizer.encode(example)
example_ids = torch.LongTensor([example_ids]).cuda()
example_emb = encoder(input_ids=example_ids).pooler_output
if idx == 0:
example_embeddings = example_emb
else:
example_embeddings = torch.cat(
(example_embeddings, example_emb), dim=0)
emb_dict[topic] = example_embeddings.cpu()
# compare the similarity and select the topk samples
similarity_list = example_embeddings.matmul(query_emb)
_, indices = torch.topk(similarity_list, k=topk)
indices = indices.tolist()
indices = indices[::-1] # reverse the order
selected_prompts = []
for index in indices:
# index = index.item()
selected_prompts.append(prompt_list[index])
return selected_prompts
def prompt_selection_for_knowledge_generation(
test_datapath, train_datapath, model_path, output_prompt_path):
"""Selecting prompts for the knowledge generation"""
print("> Selecting prompts for the knowledge generation")
train_data_by_topic, dialog_data_by_topic, dialog_examples = \
get_database(test_datapath, train_datapath)
from transformers import DPRQuestionEncoderTokenizer
print("> loading tokenizer and encoder")
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
'facebook/dpr-question_encoder-single-nq-base')
encoder = torch.load(model_path).cuda()
print("> getting dialog embeddings")
with torch.no_grad():
for idx, example in tqdm(enumerate(dialog_examples)):
dialog = example[1]
dialog_ids = tokenizer.encode(dialog)
dialog_ids = torch.LongTensor([dialog_ids]).cuda()
dialog_emb = encoder(input_ids=dialog_ids).pooler_output
if idx == 0:
dialog_embeddings = dialog_emb
else:
dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)
print("> reading test data from %s" % test_datapath)
count_out_of_list = 0
prompt_list_for_each_sample = []
with open(test_datapath, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
if topic not in train_data_by_topic:
count_out_of_list += 1
# calculate similarity
# get the query embedding
query_sent = ""
query_sent += "( " + topic + " )"
for turn in turns:
query_sent += " "
query_sent += turn
query_ids = tokenizer.encode(query_sent)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate the similarity
similarity_list = dialog_embeddings.matmul(query_emb)
_, indices = torch.sort(similarity_list)
indices = indices.tolist()
selected_topics = {}
selected_prompts = []
num_prompt = 0
for index in indices:
example = dialog_examples[index]
topic_temp = example[0]
if topic_temp not in selected_topics:
selected_topics[topic_temp] = True
selected_prompts.append(example[2])
num_prompt += 1
if num_prompt == 10:
break
# get the selected samples
example_list = selected_prompts[::-1]
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
else:
num_data_sample = min(len(train_data_by_topic[topic]), 10)
total_example_list = train_data_by_topic[topic]
# query_sent
query_sent = ""
query_sent += "( " + topic + " )"
for turn in turns:
query_sent += " "
query_sent += turn
dialog_list = dialog_data_by_topic[topic]
assert len(dialog_list) == num_data_sample
# calculate the similarity
selected_examples = select_prompts_based_on_similarity(
query_sent, dialog_list, total_example_list,
topic, tokenizer, encoder, topk=num_data_sample)
example_list = selected_examples
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
print("writing to %s" % output_prompt_path)
with open(output_prompt_path, "w") as f:
for instance in tqdm(prompt_list_for_each_sample):
json.dump(instance, f)
f.write("\n")
def prompt_selection_for_response_generation(input_path, output_path, seed):
"""Selecting prompts for the response generation"""
print("> Selecting prompts for the response generation")
print("> set random seed")
np.random.seed(seed)
prompt_example_list = []
print("> reading data from %s" % input_path)
with open(input_path, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
# get the topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
knowledge = splits[2]
response = splits[3]
turns = dialog_context.split(" [SEP] ")[-3:]
if knowledge == "no_passages_used":
continue
# calculate the overlap ratio
from nltk import word_tokenize
knowledge_sent_token_list = word_tokenize(knowledge)
knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
response_token_list = response.split()
response_len = len(response_token_list)
num_overlap_token = 0
for token in response_token_list:
if token in knowledge_sent_token_dict:
num_overlap_token += 1
# filtering the data based on the ratio
if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
continue
prompt_example = ""
# add dialog context
prompt_example += "Topic: " + topic + ". "
prompt_example += "User says: " + turns[-1] + " "
prompt_example += "We know that: " + knowledge + " "
prompt_example += "System replies: " + response
prompt_example_list.append(prompt_example)
print("> shuffle the prompt examples (total %d)" % len(prompt_example_list))
np.random.shuffle(prompt_example_list)
print("> Prompt example:")
print(prompt_example_list[0])
print("> writing to %s" % output_path)
with open(output_path, "w") as f:
# f.write("Generate the System's response based on the knowledge sentence:\n")
for i in tqdm(range(20)):
example = prompt_example_list[i]
f.write(example + "\n")
def prepare_input_for_response_generation(test_file, knowledge_file, output_file):
"""Preparing inputs for the response generation"""
# get the knowledge list
with open(knowledge_file, "r") as f:
knowledge_list = f.readlines()
with open(test_file, "r") as fr:
with open(output_file, "w") as fw:
for line_num, line in enumerate(tqdm(fr)):
line = line.strip()
splits = line.split("\t")
# prepare topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
response = splits[3]
knowledge = knowledge_list[line_num]
knowledge = knowledge.strip()
if "<|endoftext|>" in knowledge:
knowledge = knowledge.replace("<|endoftext|>", "")
# write to the output file
fw.write(topic + "\t" + dialog_context + "\t" \
+ knowledge + "\t" + response + "\n")
if __name__ == "__main__": if __name__ == "__main__":
...@@ -168,3 +482,13 @@ if __name__ == "__main__": ...@@ -168,3 +482,13 @@ if __name__ == "__main__":
elif params.func == "process_woi_dataset": elif params.func == "process_woi_dataset":
process_woi_dataset(params.input_file, params.output_file) process_woi_dataset(params.input_file, params.output_file)
elif params.func == "get_prompts":
prompt_selection_for_knowledge_generation(
params.test_file, params.train_file, params.model_file, params.output_file)
prompt_selection_for_response_generation(
params.train_file, params.output_file, params.seed)
elif params.func == "prepare_input":
prepare_input_for_response_generation(
params.test_file, params.knowledge_file, params.output_file)
"""Prompting the pretrained language model to generate knowledge/response"""
import json
import torch
from nltk import word_tokenize
@@ -27,7 +29,9 @@ def model_provider(pre_process=True, post_process=True):
def generate_samples_by_prompting_input_from_file(model):
"""Prompt a pretrained language model to generate knowledge/response"""
# get tokenizer
args = get_args()
tokenizer = get_tokenizer()
@@ -57,17 +61,17 @@ def generate_samples_by_prompting_input_from_file(model):
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
# get the prompt examples based on the key
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt_examples_dict[key] = prompt
else:
# prompts are fixed for all test samples
with open(args.prompt_file, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:args.num_prompt_examples]
@@ -77,13 +81,14 @@ def generate_samples_by_prompting_input_from_file(model):
instance = instance.strip()
prompt += instance + " \n"
# only two prompt types (i.e., knowledge and response) are allowed
assert args.prompt_type in ["knowledge", "response"]
context_count = 0
model.eval()
# perform prompting
with torch.no_grad():
while True:
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
input_str = all_raw_text[input_pos]
@@ -92,16 +97,17 @@ def generate_samples_by_prompting_input_from_file(model):
control_codes = splits[0].split(" [CTRL] ")
topic = control_codes[0]
# first add the prompt into the inputs
if args.dynamic_prompt:
turns = splits[1].split(" [SEP] ")
last_turn = turns[-1]
key = topic + " " + last_turn
raw_text = prompt_examples_dict[key]
else:
raw_text = prompt
if args.prompt_type == "knowledge":
# construct inputs for knowledge generation
turns = splits[1].split(" [SEP] ")
context = turns[-1]
if " -> " in raw_text and " => " not in raw_text:
@@ -110,11 +116,11 @@ def generate_samples_by_prompting_input_from_file(model):
raw_text += "( " + context + " ) " + topic + " =>"
else:
# construct inputs for response generation
# args.prompt_type == "response"
turns = splits[1].split(" [SEP] ")
knowledge = splits[2]
knowledge = " ".join(word_tokenize(knowledge))
last_turn = turns[-1]
knowledge = knowledge.strip()
last_turn = last_turn.strip()
@@ -137,9 +143,9 @@ def generate_samples_by_prompting_input_from_file(model):
for _, decode_tokens in enumerate(token_stream):
pass
# write the generated output to the output file
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
@@ -147,13 +153,11 @@ def generate_samples_by_prompting_input_from_file(model):
generated_output = trim_decode_tokens.split("\n")[0]
generated_output = generated_output.strip()
fname_out.write(generated_output)
fname_out.write("\n")
raw_text = None
context_count += 1
if input_pos == input_count:
return
@@ -174,4 +178,5 @@ def main():
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
# perform the prompting
generate_samples_by_prompting_input_from_file(model)
#!/bin/bash
# Data preparation for our framework: preprocessing the WoW and WoI datasets
# The datasets can be downloaded through the following links:
# WoW: https://parl.ai/projects/wizard_of_wikipedia/
# WoI: https://parl.ai/projects/sea/
DIR=`pwd`
mkdir -p $DIR/tasks/knwl_dialo/data
@@ -9,6 +14,9 @@ python ${DIR}/tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --inp
# We provide the following script to process the raw data from Wizard of Internet
python ${DIR}/tasks/knwl_dialo/preprocessing.py --func process_woi_dataset --input_file <PATH_OF_THE_INPUT_DATA> --output_file <PATH_OF_THE_OUTPUT_DATA>
# Obtain the knowledge generation prompts and response generation prompts
python ${DIR}/tasks/knwl_dialo/preprocessing.py --func get_prompts --test_file <PATH_OF_THE_PROCESSED_TEST_DATA> --train_file <PATH_OF_THE_PROCESSED_TRAIN_DATA> --model_file <PATH_OF_THE_DPR_MODEL> --output_file <PATH_OF_THE_OUTPUT_FILE>
# Alternatively, we recommend directly downloading the already processed file:
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1vP0eGxhkbWfeJ2dUUOEAflbOZq-Jlde_' -O data.gz
#!/bin/bash
# This script is used to evaluate the F1 or KF1 scores.
WORLD_SIZE=1
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
......
#!/bin/bash
# Finetune a pretrained language model to generate the context-relevant knowledge
# The input is the dialogue context, and the output is the relevant knowledge
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
@@ -16,8 +20,6 @@ TRAIN_PATH=<Specify path for the training dataset>
TEST_PATH=<Specify path for the test dataset>
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
@@ -31,17 +33,13 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--lr 1.5e-5 \
--min-lr 1.0e-5 \
--lr-decay-style cosine \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--save-interval 10000 \
--save ${OUTPUT_MODEL_PATH} \
--pretrained-checkpoint ${CHECKPOINT_PATH} \
--weight-decay 0.1 \
--adam-beta2 0.95 \
--log-params-norm \
--log-num-zeros-in-grad \
--fp16 \
@@ -51,7 +49,6 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--task KNWL-DIALO-FINETUNE \
--module knowledge \
--spec-toks [SEP],[CTRL],[PAD] \
--train-data ${TRAIN_PATH} \
--test-data ${TEST_PATH} \
--tokenizer-type GPT2BPETokenizer
#!/bin/bash
# Finetune a pretrained language model to generate the corresponding response
# The input is the dialogue context and knowledge, and the output is the response
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
@@ -16,8 +20,6 @@ TRAIN_PATH=<Specify path for the training dataset>
TEST_PATH=<Specify path for the test dataset>
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
@@ -31,17 +33,13 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--lr 1.0e-5 \
--min-lr 5.0e-6 \
--lr-decay-style cosine \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--save-interval 10000 \
--save ${OUTPUT_MODEL_PATH} \
--pretrained-checkpoint ${CHECKPOINT_PATH} \
--weight-decay 0.1 \
--adam-beta2 0.95 \
--log-params-norm \
--log-num-zeros-in-grad \
--fp16 \
@@ -51,7 +49,6 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--task KNWL-DIALO-FINETUNE \
--module response \
--spec-toks [SEP],[CTRL],[PAD] \
--train-data ${TRAIN_PATH} \
--test-data ${TEST_PATH} \
--tokenizer-type GPT2BPETokenizer
#!/bin/bash
# Preparing the input file for the response generation (second-stage prompting)
DIR=`pwd`
python ${DIR}/tasks/knwl_dialo/preprocessing.py --func prepare_input --test_file <PATH_OF_THE_PROCESSED_TEST_DATA> --knowledge_file <PATH_OF_THE_GENERATED_KNOWLEDGE_DATA> --output_file <PATH_OF_THE_OUTPUT_FILE>
#!/bin/bash
# Stage-1: Prompt a pretrained language model to generate the context-relevant knowledge
# The input contains prompts and the current dialogue context; the output is the relevant knowledge
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
@@ -10,25 +14,24 @@ DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
CHECKPOINT_PATH=<Specify path for the language model>
INPUT_PATH=<Specify path for the input test dataset>
VOCAB_PATH=<Specify path for the vocab file>
MERGE_PATH=<Specify path for the merge file>
OUTPUT_PATH=<Specify path for the output>
PROMPT_PATH=<Specify path for the prompts>
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 1 \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--load ${CHECKPOINT_PATH} \
--fp16 \
--DDP-impl torch \
--tokenizer-type GPT2BPETokenizer \
--sample-input-file ${INPUT_PATH} \
--sample-output-file ${OUTPUT_PATH} \
--prompt-file ${PROMPT_PATH} \
......
#!/bin/bash
# Stage-2: Prompt a pretrained language model to generate the corresponding response
# The input contains prompts, the current dialogue context, and the knowledge generated in Stage-1
# The output is the corresponding response.
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
@@ -10,25 +15,24 @@ DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
CHECKPOINT_PATH=<Specify path for the language model>
INPUT_PATH=<Specify path for the input test dataset>
VOCAB_PATH=<Specify path for the vocab file>
MERGE_PATH=<Specify path for the merge file>
OUTPUT_PATH=<Specify path for the output>
PROMPT_PATH=<Specify path for the prompts>
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 1 \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--load ${CHECKPOINT_PATH} \
--fp16 \
--DDP-impl torch \
--tokenizer-type GPT2BPETokenizer \
--sample-input-file ${INPUT_PATH} \
--sample-output-file ${OUTPUT_PATH} \
--prompt-file ${PROMPT_PATH} \
......
"""Utils (functions) for both prompting and finetuning"""
import torch import torch
from megatron import mpu from megatron import mpu
from megatron import get_args from megatron import get_args
...@@ -11,7 +13,11 @@ from megatron.model import Float16Module ...@@ -11,7 +13,11 @@ from megatron.model import Float16Module
def get_ltor_attention_masks_and_position_ids(data, eod_token_id): def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
"""Build attention masks and position id for left to right model.""" """
Build attention masks and position id for left to right model.
Different from the existing get_ltor_masks_and_position_ids function,
we add padding to the input sequences to make sure their lengths are the same.
"""
micro_batch_size, seq_length = data.size()
@@ -38,6 +44,7 @@ def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
def switch(val1, val2, boolean):
"""Return either val1 or val2 depending on boolean"""
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
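# Example (illustrative): with val1 = [10, 20, 30], val2 = [7, 8, 9] and
# boolean = [0, 1, 1], switch(val1, val2, boolean) returns [10, 8, 9], i.e.,
# val2 is taken wherever boolean is 1 and val1 is kept elsewhere.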
@@ -46,6 +53,7 @@ def switch(val1, val2, boolean):
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
layer_past=None, get_key_value=None,
forward_method_parallel_output=None):
"""Forward step to get the outputs"""
# functions the correct size
args = get_args()
@@ -73,24 +81,28 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
def pad_batch(batch, pad_id, args):
"""Pad the context tokens using pad_id"""
context_lengths = []
for tokens in batch:
context_length = len(tokens)
# padding
if context_length < args.seq_length:
tokens.extend([pad_id] * (args.seq_length - context_length))
# record the original context length
context_lengths.append(context_length)
return batch, context_lengths
def get_batch(context_tokens):
"""Generate batch from context tokens."""
args = get_args()
tokenizer = get_tokenizer()
# Move to GPU.
tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
# Get the attention mask and position ids for the context tokens.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
@@ -104,6 +116,7 @@ def get_batch(context_tokens):
def sample_sequence_batch(model, context_tokens, context_lengths,
attention_mask, position_ids,
maxlen=None, type_ids=None):
"""Obtain batch-level generation outputs"""
args = get_args()
tokenizer = get_tokenizer()
@@ -122,18 +135,18 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
counter = 0
org_context_length = context_length
# prepare batch size, context tokens, maximum length
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = args.seq_length - 1
lengths = torch.ones([batch_size]).long().cuda() * maxlen
# start the generation process
while context_length <= (maxlen):
# forward and obtain the logits
output = forward_step(model, tokens,
position_ids,
attention_mask,
@@ -143,11 +156,13 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
assert output is not None
logits = output[:, context_length - 1, :]
# generate tokens iteratively
if mpu.is_pipeline_last_stage():
prev = torch.argmax(logits, dim=-1).view(-1)
# start to add new tokens when the generated length
# exceeds the context length
started = context_lengths <= context_length
new_tokens = switch(
tokens[:, context_length].view(-1), prev, started)
tokens[:, context_length] = new_tokens
@@ -155,6 +170,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
group = mpu.get_embedding_group()
torch.distributed.broadcast(new_tokens, src, group)
# check whether the generation is finished
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
@@ -189,13 +205,17 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
def get_token_stream(model, context_tokens):
"""Get output tokens iteratively"""
# get tokenizer
args = get_args()
tokenizer = get_tokenizer()
# padding for context tokens
context_tokens, context_lengths = pad_batch(context_tokens,
tokenizer.eod, args)
# move tokens to CUDA
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
@@ -206,9 +226,11 @@ def get_token_stream(model, context_tokens):
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
# prepare batch
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
# get generation outputs
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
context_length_tensor,
attention_mask, position_ids)
......
@@ -85,8 +85,6 @@ def get_tasks_args(parser):
' take from each question pool')
# parameters for the knowledgeable dialogue generation
group.add_argument("--sample-input-file", type=str, default=None,
help='Get input from file instead of interactive mode, '
'each line is an input.')
@@ -102,20 +100,12 @@ def get_tasks_args(parser):
help='using different prompts for different test samples')
group.add_argument('--module', type=str, default="",
help='either knowledge generation (knowledge) or response generation (response)')
group.add_argument('--guess-file', type=str, default="",
help='datapath for generated sentences')
group.add_argument('--answer-file', type=str, default="",
help='datapath for golden sentences')
group.add_argument('--spec-toks', type=str, default=None,
help='additional special tokens')
return parser
......