Commit fb3328fe authored by zihanl

upload knwl_dialo

parent ec202742
......@@ -7,10 +7,10 @@ dist/
tensorboard
commands/
commands_new/
commands_others/
commands_final/
*.log
logs
*.so
*.out
train_gpt_conv.py
dialogctrl/
control_gen/
\ No newline at end of file
dialogctrl/
\ No newline at end of file
......@@ -263,6 +263,7 @@ def generate_samples_prompt_input_from_file(model):
args = get_args()
tokenizer = get_tokenizer()
from nltk import word_tokenize
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
......@@ -282,16 +283,35 @@ def generate_samples_prompt_input_from_file(model):
fname_out = open(sample_output_file, "w")
# Read the prompt file
with open(args.prompt_file, "r") as f:
prompt_examples = f.readlines()
if args.dynamic_prompt:
prompt_examples_dict = {}
with open(args.prompt_file, "r") as f:
for i, line in enumerate(f):
line = line.strip()
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt_examples_dict[key] = prompt
else:
with open(args.prompt_file, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:args.num_prompt_examples]
prompt_examples = prompt_examples[:args.num_prompt_examples]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
assert args.prompt_type in ["context", "keyphrase"]
assert args.prompt_type in ["knowledge", "knowledge_notopic", "dialogue", "dialogue_notopic"]
context_count = 0
model.eval()
with torch.no_grad():
......@@ -306,25 +326,77 @@ def generate_samples_prompt_input_from_file(model):
control_codes = splits[0].split(" [CTRL] ")
topic = control_codes[0]
raw_text = prompt
if args.prompt_type == "context":
if args.dynamic_prompt:
turns = splits[1].split(" [SEP] ")
context = turns[-1]
raw_text += "( " + context + " ) " + topic + " :"
last_turn = turns[-1]
key = topic + " " + last_turn
raw_text = prompt_examples_dict[key]
else:
keyphrase_list = control_codes[1:]
raw_text = prompt
for i, keyphrase in enumerate(keyphrase_list):
if i == 0:
raw_text += "( "
if args.prompt_type == "knowledge":
turns = splits[1].split(" [SEP] ")
context = turns[-1]
raw_text += "( " + context + " ) " + topic + " =>"
# raw_text += "( " + context + " ) " + topic + ":"
# raw_text += "( " + context + " ) " + topic + " ->"
elif args.prompt_type == "knowledge_notopic":
turns = splits[1].split(" [SEP] ")[-3:]
for j, turn in enumerate(turns):
if j != 0:
raw_text += " "
else:
raw_text += "; "
raw_text += keyphrase
raw_text += "( " + turn + " )"
raw_text += " =>"
elif args.prompt_type == "dialogue":
turns = splits[1].split(" [SEP] ")
# context = turns[-1]
ctrl_sent = splits[2]
ctrl_sent = " ".join(word_tokenize(ctrl_sent))
# ## version one
# turns = turns[-3:]
# raw_text += "Topic: " + topic + ". "
# if len(turns) == 2:
# for idx, turn in enumerate(turns):
# if idx % 2 == 0:
# raw_text += "System: " + turn + " "
# else:
# raw_text += "User: " + turn + " "
# else:
# for idx, turn in enumerate(turns):
# if idx % 2 == 0:
# raw_text += "User: " + turn + " "
# else:
# raw_text += "System: " + turn + " "
# raw_text += "We know that: " + ctrl_sent + " "
# raw_text += "Therefore, the System will say:"
## version two
last_turn = turns[-1]
ctrl_sent = ctrl_sent.strip()
last_turn = last_turn.strip()
raw_text += "Topic: " + topic + ". "
raw_text += "User says: " + last_turn + " "
raw_text += "We know that: " + ctrl_sent + " "
raw_text += "System replies:"
if len(keyphrase_list) > 0:
raw_text += " ) "
raw_text += topic + " :"
else:
turns = splits[1].split(" [SEP] ")
# context = turns[-1]
ctrl_sent = splits[2]
ctrl_sent = " ".join(word_tokenize(ctrl_sent))
## version two
last_turn = turns[-1]
ctrl_sent = ctrl_sent.strip()
last_turn = last_turn.strip()
raw_text += "User says: " + last_turn + " "
raw_text += "We know that: " + ctrl_sent + " "
raw_text += "System replies:"
input_pos += 1
raw_text_len = len(raw_text)
......
"""Build Dataset for Controllable Coversational Model"""
import os
import torch
import numpy as np
from megatron import get_tokenizer
from megatron import print_rank_0
def read_data(tokenizer, data_path, train_module):
"""read and tokenize dialog data"""
data_list = []
with open(data_path, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
length_split = len(splits)
assert length_split == 2 or length_split == 3 or length_split == 4
if train_module == "dialog":
# if length_split == 2:
# continue
dialog_context = splits[0]
if length_split > 2:
ctrl_sent = splits[-2]
response = splits[-1]
# only take the last three turns in the dialog context
turns = dialog_context.split(" [SEP] ")
turns = turns[-3:]
# input_ids
input_ids = []
if length_split > 2:
input_ids.extend(tokenizer.tokenize("( " + ctrl_sent + " )"))
for idx, turn in enumerate(turns):
if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
turn = turn + " ."
input_ids.extend(tokenizer.tokenize(turn))
# output_ids
output_ids = tokenizer.tokenize(response)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
elif train_module == "control":
if length_split == 2:
continue
dialog_context = splits[0]
ctrl_sent = splits[-2]
ctrl_code = splits[1] if length_split == 4 else None
turns = dialog_context.split(" [SEP] ")
# put control code at the beginning
input_ids = []
if ctrl_code:
ctrl_code_list = ctrl_code.split(" [CTRL] ")
for code in ctrl_code_list:
input_ids.extend(tokenizer.tokenize("( " + code + " )"))
turns = turns[-3:]
for turn in turns:
if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
turn = turn + " ."
input_ids.extend(tokenizer.tokenize(turn))
# output_ids
outputs = ctrl_sent
output_ids = tokenizer.tokenize(outputs)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
else:
raise ValueError("Please input a correct train-module name! " \
"(either dialog or cnotrol))")
return data_list
def read_data_v2(tokenizer, data_path, train_module,
last_turn=False, no_control_code=False, add_separator=False,
add_ctrl_code_to_dialog=False, remove_ctrl_sent=False):
"""
Read and tokenize data for version 2 (v2) data files.
Format: control code \t dialog context \t control sentence \t response.
Response only comes from the wizard.
Currently, this function is used to build test dataset for calculating PPL.
"""
data_list = []
with open(data_path, "r") as f:
for i, line in enumerate(f):
line = line.rstrip()
splits = line.split("\t")
assert len(splits) == 4
control_code = splits[0]
dialog_context = splits[1]
control_sent = splits[2]
response = splits[3]
turns = dialog_context.split(" [SEP] ")
turns = turns[-3:]
if train_module == "dialog":
# input_ids
if add_ctrl_code_to_dialog:
ctrl_code = control_code.split(" [CTRL] ")[0]
input_ids = tokenizer.tokenize("( " + ctrl_code + " )")
if not remove_ctrl_sent and control_sent != "no_passages_used":
input_ids.extend(tokenizer.tokenize("( " + control_sent + " )")[:256])
else:
if remove_ctrl_sent or control_sent == "no_passages_used":
input_ids = []
else:
input_ids = tokenizer.tokenize("( " + control_sent + " )")[:256]
for turn in turns:
if add_separator:
turn = "<< " + turn + " >>"
input_ids.extend(tokenizer.tokenize(turn))
if add_separator:
input_ids.extend(tokenizer.tokenize(":"))
# output_ids
output_ids = tokenizer.tokenize(response)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
elif train_module == "control":
# skip example without control sentences
if control_sent == "no_passages_used":
continue
input_ids = []
if not no_control_code:
ctrl_code_list = control_code.split(" [CTRL] ")[:3]
# only choose maximum three control codes
for code in ctrl_code_list:
if len(code) > 0:
input_ids.extend(tokenizer.tokenize("( " + code + " )"))
if last_turn:
input_ids.extend(tokenizer.tokenize(turns[-1]))
else:
for turn in turns:
if add_separator:
turn = "<< " + turn + " >>"
input_ids.extend(tokenizer.tokenize(turn))
if add_separator:
input_ids.extend(tokenizer.tokenize(":"))
output_ids = tokenizer.tokenize(control_sent)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
else:
raise ValueError("Please input a correct train-module name! " \
"(either dialog or cnotrol))")
return data_list
def data_shuffle(data, seed):
# set random seed to make the shuffling reproducible
np.random.seed(seed)
np.random.shuffle(data)
return data
class ControlDialogDataset(torch.utils.data.Dataset):
def __init__(self, data, max_seq_len, sep_id, pad_id, eod_id):
# need to deal with padding, label masking
self.data = data
self.max_seq_len = max_seq_len
self.sep_id = sep_id
self.pad_id = pad_id
self.eod_id = eod_id
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data_dict = self.data[idx]
input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
# assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
# length_of_loss_mask == length_of_text - 1
# text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
text = input_ids + output_ids + [self.eod_id]
loss_mask = [0]*(len(input_ids)-1) + [1]*(len(output_ids)+1)
text_len = len(text)
if text_len > self.max_seq_len+1:
text = text[:self.max_seq_len+1]
loss_mask = loss_mask[:self.max_seq_len]
else:
text += [self.pad_id] * (self.max_seq_len+1 - text_len)
loss_mask += [0] * (self.max_seq_len+1 - text_len)
return {"text": np.array(text, dtype=np.int64), \
"loss_mask": np.array(loss_mask, dtype=np.int64)}
def build_train_valid_datasets(train_data_path, valid_data_path, train_module,
max_seq_len, seed, last_turn, no_control_code,
add_separator, add_ctrl_code_to_dialog, remove_ctrl_sent):
"""Build train, valid, and test datasets."""
# dataname_dict = {"wizard_of_wikipedia": {"train": "train_entity_based_control.txt", "valid": "valid_random_split_entity_based_control.txt", "test": "test_random_split_entity_based_control.txt"}}
# train_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["train"])
# valid_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["valid"])
# test_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["test"])
tokenizer = get_tokenizer()
# train_data_list = read_data(tokenizer, train_data_path, train_module)
train_data_list = read_data_v2(tokenizer, train_data_path, train_module,
last_turn, no_control_code, add_separator,
add_ctrl_code_to_dialog, remove_ctrl_sent)
valid_data_list = read_data_v2(tokenizer, valid_data_path, train_module,
last_turn, no_control_code, add_separator,
add_ctrl_code_to_dialog, remove_ctrl_sent)
# shuffle the training data
train_data_list = data_shuffle(train_data_list, seed)
# build train, valid datasets
train_dataset = ControlDialogDataset(train_data_list,
max_seq_len,
sep_id=tokenizer.sep_id,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
valid_dataset = ControlDialogDataset(valid_data_list,
max_seq_len,
sep_id=tokenizer.sep_id,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
return train_dataset, valid_dataset
def build_test_dataset(test_data_path, train_module, max_seq_len,
last_turn, no_control_code, add_separator,
add_ctrl_code_to_dialog, remove_ctrl_sent):
tokenizer = get_tokenizer()
test_data_list = read_data_v2(tokenizer, test_data_path, train_module,
last_turn, no_control_code, add_separator,
add_ctrl_code_to_dialog, remove_ctrl_sent)
test_dataset = ControlDialogDataset(test_data_list,
max_seq_len,
sep_id=tokenizer.sep_id,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
return test_dataset
import torch
from megatron import print_rank_0
def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
"""Build attention masks and position id for left to right model."""
micro_batch_size, seq_length = data.size()
# Attention mask
attention_mask = torch.tril(torch.ones(
(micro_batch_size, seq_length, seq_length), device=data.device)).view(
micro_batch_size, 1, seq_length, seq_length)
# mask padded tokens
for b in range(micro_batch_size):
for idx in range(seq_length-1):
if data[b, idx] == eod_token_id:
# pad tokens that come after the eod token
attention_mask[b, 0, idx+1:, :] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# # reset attention mask and position ids
# # Loop through the batches:
# for b in range(micro_batch_size):
# # Find indices where EOD token is.
# eod_index = position_ids[b, data[b] == eod_token_id]
# eod_index = eod_index.clone()
# # Loop through EOD indices:
# prev_index = 0
# for j in range(eod_index.size()[0]):
# i = eod_index[j]
# # Mask attention loss.
# attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# # Reset positions.
# position_ids[b, (i + 1):] -= (i + 1 - prev_index)
# prev_index = i + 1
# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
return attention_mask, position_ids
\ No newline at end of file
"""Build Dataset for Controllable Coversational Model"""
import os
import torch
import numpy as np
from megatron import get_tokenizer
from megatron import print_rank_0
def read_data_for_finetuning(tokenizer, data_path, module):
"""
Data Format: topic \t dialog context \t knowledge \t response.
"""
data_list = []
with open(data_path, "r") as f:
for i, line in enumerate(f):
line = line.rstrip()
splits = line.split("\t")
assert len(splits) == 4
topic = splits[0].split(" [CTRL] ")[0]
dialog_context = splits[1]
knowledge = splits[2]
response = splits[3]
turns = dialog_context.split(" [SEP] ")
turns = turns[-3:]
if module == "response":
# input_ids
input_ids = tokenizer.tokenize("( " + topic + " )")
if knowledge != "no_passages_used":
input_ids.extend(tokenizer.tokenize("( " + knowledge + " )")[:256])
for turn in turns:
turn = "<< " + turn + " >>"
input_ids.extend(tokenizer.tokenize(turn))
input_ids.extend(tokenizer.tokenize(":"))
# output_ids
output_ids = tokenizer.tokenize(response)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
elif module == "knowledge":
# skip example without knowledge sentences
if knowledge == "no_passages_used":
continue
input_ids = []
input_ids.extend(tokenizer.tokenize("( " + topic + " )"))
for turn in turns:
turn = "<< " + turn + " >>"
input_ids.extend(tokenizer.tokenize(turn))
input_ids.extend(tokenizer.tokenize(":"))
output_ids = tokenizer.tokenize(knowledge)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
else:
raise ValueError("Please input a correct module name! " \
"(either dialog or cnotrol))")
return data_list
def read_data_for_prompting(tokenizer, test_data_path, prompt_file,
module, num_prompt_examples, dynamic_prompt):
# get prompts
if dynamic_prompt:
import json
prompt_examples_dict = {}
with open(prompt_file, "r") as f:
for i, line in enumerate(f):
line = line.strip()
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt_examples = prompt_examples[:num_prompt_examples]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt_examples_dict[key] = prompt
else:
with open(prompt_file, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:num_prompt_examples]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
data_list = []
with open(test_data_path, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0].split(" [CTRL] ")[0]
turns = splits[1].split(" [SEP] ")[-3:]
last_turn = turns[-1]
ctrl_sent = splits[2]
response = splits[3]
if dynamic_prompt:
prompt = prompt_examples_dict[topic]
if module == "response":
# input seq
input_seq = prompt
input_seq += "Topic: " + topic + ". "
input_seq += "User says: " + last_turn + " "
input_seq += "We know that: " + ctrl_sent + " "
input_seq += "System replies:"
# output seq
output_seq = response
input_ids = tokenizer.tokenize(input_seq)
output_ids = tokenizer.tokenize(output_seq)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
elif module == "knowledge":
# input seq
input_seq = prompt
input_seq += "( " + last_turn + " ) " + topic + " =>"
# output seq
output_seq = ctrl_sent
input_ids = tokenizer.tokenize(input_seq)
output_ids = tokenizer.tokenize(output_seq)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
else:
raise ValueError("Please input a correct module name! " \
"(either dialog or cnotrol))")
return data_list
def data_shuffle(data, seed):
# set random seed to make the shuffling reproducible
np.random.seed(seed)
np.random.shuffle(data)
return data
class KnwlDialoDataset(torch.utils.data.Dataset):
def __init__(self, data, max_seq_len, pad_id, eod_id):
# need to deal with padding, label masking
self.data = data
self.max_seq_len = max_seq_len
self.pad_id = pad_id
self.eod_id = eod_id
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data_dict = self.data[idx]
input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
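# text holds up to max_seq_len+1 tokens so that inputs (text[:-1]) and labels (text[1:]) can be split downstream;
# loss_mask covers the shifted labels: zeros over the input prefix, ones over output_ids and the trailing eod token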
text = input_ids + output_ids + [self.eod_id]
loss_mask = [0]*(len(input_ids)-1) + [1]*(len(output_ids)+1)
text_len = len(text)
if text_len > self.max_seq_len+1:
text = text[:self.max_seq_len+1]
loss_mask = loss_mask[:self.max_seq_len]
else:
text += [self.pad_id] * (self.max_seq_len+1 - text_len)
loss_mask += [0] * (self.max_seq_len+1 - text_len)
return {"text": np.array(text, dtype=np.int64), \
"loss_mask": np.array(loss_mask, dtype=np.int64)}
def build_train_valid_datasets(train_data_path, valid_data_path, module,
max_seq_len, seed):
"""Build train, valid, and test datasets."""
tokenizer = get_tokenizer()
train_data_list = read_data_for_finetuning(tokenizer, train_data_path, module)
valid_data_list = read_data_for_finetuning(tokenizer, valid_data_path, module)
# shuffle the training data
train_data_list = data_shuffle(train_data_list, seed)
# build train, valid datasets
train_dataset = KnwlDialoDataset(train_data_list,
max_seq_len,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
valid_dataset = KnwlDialoDataset(valid_data_list,
max_seq_len,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
return train_dataset, valid_dataset
def build_test_dataset(test_data_path, module, max_seq_len):
tokenizer = get_tokenizer()
test_data_list = read_data_for_finetuning(tokenizer, test_data_path, module)
test_dataset = KnwlDialoDataset(test_data_list,
max_seq_len,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
return test_dataset
def build_test_dataset_for_prompting(test_data_path, prompt_file, module, max_seq_len,
num_prompt_examples, dynamic_prompt):
tokenizer = get_tokenizer()
test_data_list = read_data_for_prompting(tokenizer, test_data_path, prompt_file, module, \
num_prompt_examples, dynamic_prompt)
test_dataset = KnwlDialoDataset(test_data_list,
max_seq_len,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
return test_dataset
......@@ -7,9 +7,13 @@ from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.checkpointing import load_checkpoint
from tasks.finetune_utils import build_data_loader
from tasks.dialctrl.data import build_test_dataset
from tasks.dialctrl.finetune import model_provider, process_batch, loss_func, forward_step
from tasks.dialctrl.metrics import F1Metric
from tasks.knwl_dialo.data import build_test_dataset
from tasks.knwl_dialo.data import build_test_dataset_for_prompting
from tasks.knwl_dialo.finetune import model_provider
from tasks.knwl_dialo.finetune import process_batch
from tasks.knwl_dialo.finetune import loss_func
from tasks.knwl_dialo.finetune import forward_step
from tasks.knwl_dialo.metrics import F1Metric
from tqdm import tqdm
def test_dataset_provider():
......@@ -18,15 +22,27 @@ def test_dataset_provider():
print_rank_0('> building the test dataset for %s module ...' \
% args.train_module)
test_ds = build_test_dataset(
test_data_path=args.test_data_path,
train_module=args.train_module,
max_seq_len=args.max_seq_len,
last_turn=args.last_turn,
no_control_code=args.no_control_code,
add_separator=args.add_separator,
add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
remove_ctrl_sent=args.remove_ctrl_sent)
if args.eval_prompting:
print_rank_0('> evaluating ppl for prompting')
test_ds = build_test_dataset_for_prompting(
test_data_path=args.test_data_path,
prompt_file=args.prompt_file,
module=args.module,
max_seq_len=args.max_seq_len,
num_prompt_examples=args.num_prompt_examples,
dynamic_prompt=args.dynamic_prompt)
else:
test_ds = build_test_dataset(
test_data_path=args.test_data_path,
train_module=args.train_module,
max_seq_len=args.max_seq_len,
last_turn=args.last_turn,
no_control_code=args.no_control_code,
add_separator=args.add_separator,
add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
remove_ctrl_sent=args.remove_ctrl_sent)
print_rank_0("> finished creating the test dataset for %s module ..." \
% args.train_module)
......@@ -93,7 +109,7 @@ def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
print_rank_0('done :-)')
def evaluate_f1(guess_file, answer_file, remove_stopwords):
def evaluate_f1(guess_file, answer_file):
guess_list = []
print_rank_0('reading %s' % guess_file)
......@@ -116,7 +132,7 @@ def evaluate_f1(guess_file, answer_file, remove_stopwords):
assert len(guess_list) == len(answer_list), \
"lengths of guess and answer are different!"
precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list, remove_stopwords)
precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))
print_rank_0('done :-)')
......@@ -124,10 +140,10 @@ def evaluate_f1(guess_file, answer_file, remove_stopwords):
def main():
args = get_args()
if 'ppl' in args.task:
evaluate_ppl(test_dataset_provider, model_provider, forward_step)
elif 'f1' in args.task:
evaluate_f1(args.guess_file, args.answer_file, args.remove_stopwords)
evaluate_f1(args.guess_file, args.answer_file)
"""Controllable Dialogue Finetuning"""
"""Dialogue Finetuning"""
import torch
from functools import partial
from megatron import mpu
from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model import GPTModel
from megatron.training import evaluate_and_print_results
from megatron.training import get_model
from megatron.utils import average_losses_across_data_parallel_group
from megatron.initialize import initialize_megatron
from tasks.finetune_utils import finetune
from tasks.dialctrl.data import build_train_valid_datasets
from tasks.dialctrl.utils import get_ltor_attention_masks_and_position_ids
from tasks.knwl_dialo.data import build_train_valid_datasets
from tasks.knwl_dialo.utils import get_ltor_attention_masks_and_position_ids
from tasks.knwl_dialo.utils import get_token_stream
def model_provider(pre_process=True, post_process=True):
......@@ -113,8 +116,98 @@ def forward_step(batch, model):
return output_tensor, partial(loss_func, loss_mask)
def generate_samples_input_from_file(model):
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
input_pos = 0
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('`sample-output-file` not specified, setting '
'it to {}'.format(sample_output_file))
else:
sample_output_file = args.sample_output_file
fname_out = open(sample_output_file, "w")
context_count = 0
model.eval()
with torch.no_grad():
while True:
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
input_pos += 1
raw_text_len = len(raw_text)
context_tokens = tokenizer.tokenize(raw_text)
else:
context_tokens = tokenizer.tokenize("EMPTY TEXT")
if input_pos % 100 == 0:
print_rank_0("input_pos: %d" % input_pos)
token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream):
pass
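# the loop above only drains the generator; decode_tokens ends up holding the last yielded (tokens, lengths) pair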
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
if "\r" in trim_decode_tokens:
trim_decode_tokens = trim_decode_tokens.replace("\r", "")
if "\n" in trim_decode_tokens:
trim_decode_tokens = trim_decode_tokens.replace("\n", "")
fname_out.write(trim_decode_tokens)
fname_out.write("\n")
raw_text = None
context_count += 1
if input_pos == input_count:
return
def run_generation(model_provider):
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
# Set up model and load checkpoint.
model = get_model(model_provider)
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
generate_samples_input_from_file(model)
def main():
finetune(train_valid_datasets_provider, model_provider, \
forward_step=forward_step)
args = get_args()
if "finetune" in args.task:
finetune(train_valid_datasets_provider, model_provider, \
forward_step=forward_step)
else:
# generate
run_generation(model_provider)
......@@ -61,7 +61,7 @@ class F1Metric:
return precision, recall, f1
@staticmethod
def compute_each_pair(guess: str, answer: str, rm_sw: bool):
def compute_each_pair(guess: str, answer: str):
if answer == "":
return None, None, None
if guess == "":
......@@ -69,26 +69,17 @@ class F1Metric:
g_tokens = normalize_answer(guess).split()
a_tokens = normalize_answer(answer).split()
if rm_sw:
g_tokens = remove_stopwords(g_tokens)
a_tokens = remove_stopwords(a_tokens)
if len(a_tokens) == 0:
return None, None, None
if len(g_tokens) == 0:
return 0, 0, 0
precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
return precision, recall, f1
@staticmethod
def compute_all_pairs(guesses: List[str], answers: List[str], rm_sw=False):
def compute_all_pairs(guesses: List[str], answers: List[str]):
# additional argument:
# rm_sw: whether to remove stopwords
assert len(guesses) == len(answers)
precision_list, recall_list, f1_list = [], [], []
for guess, answer in zip(guesses, answers):
precision, recall, f1 = F1Metric.compute_each_pair(guess, answer, rm_sw)
precision, recall, f1 = F1Metric.compute_each_pair(guess, answer)
if precision is None or recall is None or f1 is None:
continue
precision_list.append(precision)
......
import json
import torch
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from tasks.knwl_dialo.utils import get_token_stream
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
return model
def generate_samples_by_prompting_input_from_file(model):
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
input_pos = 0
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('`sample-output-file` not specified, setting '
'it to {}'.format(sample_output_file))
else:
sample_output_file = args.sample_output_file
fname_out = open(sample_output_file, "w")
# Read the prompt file
if args.dynamic_prompt:
prompt_examples_dict = {}
with open(args.prompt_file, "r") as f:
for i, line in enumerate(f):
line = line.strip()
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
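# each JSON line is assumed to map a single key (topic plus last user turn, matching the lookup key built below) to a list of prompt examples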
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt_examples_dict[key] = prompt
else:
with open(args.prompt_file, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:args.num_prompt_examples]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
assert args.prompt_type in ["knowledge", "response"]
context_count = 0
model.eval()
with torch.no_grad():
while True:
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
input_str = all_raw_text[input_pos]
input_str = input_str.strip()
splits = input_str.split("\t")
control_codes = splits[0].split(" [CTRL] ")
topic = control_codes[0]
if args.dynamic_prompt:
turns = splits[1].split(" [SEP] ")
last_turn = turns[-1]
key = topic + " " + last_turn
raw_text = prompt_examples_dict[key]
else:
raw_text = prompt
if args.prompt_type == "knowledge":
turns = splits[1].split(" [SEP] ")
context = turns[-1]
raw_text += "( " + context + " ) " + topic + " =>"
else:
# args.prompt_type == "response":
turns = splits[1].split(" [SEP] ")
knowledge = splits[2]
knowledge = " ".join(word_tokenize(knowledge))
last_turn = turns[-1]
knowledge = knowledge.strip()
last_turn = last_turn.strip()
raw_text += "Topic: " + topic + ". "
raw_text += "User says: " + last_turn + " "
raw_text += "We know that: " + knowledge + " "
raw_text += "System replies:"
input_pos += 1
raw_text_len = len(raw_text)
context_tokens = tokenizer.tokenize(raw_text)
else:
context_tokens = tokenizer.tokenize("EMPTY TEXT")
if input_pos % 100 == 0:
print_rank_0("input_pos: %d" % input_pos)
token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream):
pass
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
generated_output = trim_decode_tokens.split("\n")[0]
generated_output = generated_output.strip()
fname_out.write(generated_output)
fname_out.write("\n")
raw_text = None
context_count += 1
if input_pos == input_count:
return
def main():
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
# Set up model and load checkpoint.
model = get_model(model_provider)
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
generate_samples_by_prompting_input_from_file(model)
import torch
from megatron import mpu
from megatron import get_args
from megatron import get_tokenizer
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
"""Build attention masks and position id for left to right model."""
micro_batch_size, seq_length = data.size()
# Attention mask
attention_mask = torch.tril(torch.ones(
(micro_batch_size, seq_length, seq_length), device=data.device)).view(
micro_batch_size, 1, seq_length, seq_length)
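# lower-triangular causal mask: entry [b, 0, i, j] is 1 when position i may attend to position j (j <= i); converted to boolean below, where True marks masked-out positions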
# mask padded tokens
for b in range(micro_batch_size):
for idx in range(seq_length-1):
if data[b, idx] == eod_token_id:
# pad tokens that come after the eod token
attention_mask[b, 0, idx+1:, :] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
return attention_mask, position_ids
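# element-wise select helper: returns val2 where boolean is 1 and val1 where boolean is 0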
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
layer_past=None, get_key_value=None,
forward_method_parallel_output=None):
# Hidden size changes when not using recompute, need to tell p2p_communicate
# functions the correct size
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
input_tensor = recv_forward()
# Forward pass through the model.
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
if get_key_value:
output_tensor, layer_past = output_tensor
send_forward(output_tensor)
args.seq_length = orig_seq_length
if get_key_value:
return output_tensor, layer_past
return output_tensor
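# pad each context in the batch to args.seq_length in place and record the original (unpadded) lengths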
def pad_batch(batch, pad_id, args):
context_lengths = []
for tokens in batch:
context_length = len(tokens)
if context_length < args.seq_length:
tokens.extend([pad_id] * (args.seq_length - context_length))
context_lengths.append(context_length)
return batch, context_lengths
def get_batch(context_tokens):
"""Generate batch from context tokens."""
args = get_args()
tokenizer = get_tokenizer()
# Move to GPU.
tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)
return tokens, attention_mask, position_ids
def sample_sequence_batch(model, context_tokens, context_lengths,
attention_mask, position_ids,
maxlen=None, type_ids=None):
args = get_args()
tokenizer = get_tokenizer()
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
# added eos_id to support the function generate_samples_eval that passes
# eos_id as an argument and needs termination when that id is found.
if hasattr(args, 'eos_id'):
eos_id = args.eos_id
else:
eos_id = tokenizer.eod
counter = 0
org_context_length = context_length
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
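# cap generation at out_seq_length tokens beyond the shortest context, never exceeding seq_length - 1 in total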
if maxlen is None:
maxlen = args.seq_length - 1
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
lengths = torch.ones([batch_size]).long().cuda() * maxlen
while context_length <= (maxlen):
output = forward_step(model, tokens,
position_ids,
attention_mask,
tokentype_ids=type_ids,
forward_method_parallel_output=False)
if mpu.is_pipeline_last_stage():
assert output is not None
logits = output[:, context_length - 1, :]
if mpu.is_pipeline_last_stage():
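# greedy decoding: argmax over the vocabulary; switch() keeps the original token for sequences whose context extends past this position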
prev = torch.argmax(logits, dim=-1).view(-1)
started = context_lengths <= context_length
new_tokens = switch(
tokens[:, context_length].view(-1), prev, started)
tokens[:, context_length] = new_tokens
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(new_tokens, src, group)
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
yield tokens, lengths
else:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
new_tokens = torch.empty_like(tokens[:, context_length])
torch.distributed.broadcast(new_tokens, src, group)
tokens[:, context_length] = new_tokens
yield tokens, None
else:
yield None, None
done = torch.cuda.ByteTensor([0])
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
context_length += 1
counter += 1
if done:
break
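# pad and broadcast the contexts across tensor-parallel ranks, then yield tokens incrementally as sample_sequence_batch generates them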
def get_token_stream(model, context_tokens):
args = get_args()
tokenizer = get_tokenizer()
context_tokens, context_lengths = pad_batch(context_tokens,
tokenizer.eod, args)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
torch.distributed.broadcast(context_length_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor,
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
context_length_tensor,
attention_mask, position_ids)
for tokens, lengths in batch_token_iterator:
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths
else:
yield None, None
......@@ -84,9 +84,24 @@ def get_tasks_args(parser):
help='Av.rank validation: how many other negatives to'
' take from each question pool')
# finetune for controllable dialogue
group.add_argument('--train-module', type=str, default="",
help='either control module or dialogue model (control or dialog)')
# parameters for the knowledgeable dialogue generation
group.add_argument("--out-seq-length", type=int, default=1024,
help='Size of the output generated text.')
group.add_argument("--sample-input-file", type=str, default=None,
help='Get input from file instead of interactive mode, '
'each line is an input.')
group.add_argument("--sample-output-file", type=str, default=None,
help='Output file generated from --sample-input-file')
group.add_argument('--prompt-file', type=str, default="",
help='prompting file')
group.add_argument('--prompt-type', type=str, default="",
help='prompt type (knowledge or response)')
group.add_argument('--num-prompt-examples', type=int, default=10,
help='number of prompt examples')
group.add_argument('--dynamic-prompt', action='store_true', default=False,
help='using different prompts for different test samples')
group.add_argument('--module', type=str, default="",
help='either knowledge generation (knowledge) or response generation (response)')
group.add_argument('--train-data-path', type=str, default="",
help='datapath for training set')
group.add_argument('--test-data-path', type=str, default="",
......@@ -99,29 +114,8 @@ def get_tasks_args(parser):
help='maximum sequence length')
group.add_argument('--spec-toks', type=str, default=None,
help='additional special tokens')
group.add_argument('--last-turn', action='store_true',
help='only use last turn for control model')
group.add_argument('--no-control-code', action='store_true',
help='removing control code in the training for control model')
group.add_argument('--remove-stopwords', action='store_true',
help='removing stopwords when evaluating F1-score')
group.add_argument('--add-separator', action='store_true',
help='add separator between turns and add colon before generation')
group.add_argument('--add-ctrl-code-to-dialog', action='store_true',
help='add control code in the dialog modeling')
group.add_argument('--remove-ctrl-sent', action='store_true',
help='dont use control sentence in dialog modeling')
# finetune for controllable generation
group.add_argument('--wiki-path', type=str, default="",
help='data path for the wikipedia corpus')
group.add_argument('--tokenized-path', type=str, default="",
help='data path for the tokenized file')
group.add_argument('--prop', type=float, default=1.0,
help='Proportion of data used for training')
group.add_argument('--max-instance', type=int, default=10000000,
help='Maximum number of instances used for training')
group.add_argument('--eval-prompting', action='store_true',
help='Whether to evaluate prompting')
return parser
......@@ -146,12 +140,12 @@ if __name__ == '__main__':
from orqa.evaluate_orqa import main
elif args.task in ['RET-FINETUNE-NQ']:
from orqa.supervised.finetune import main
elif args.task == 'control-gen':
from control_gen.finetune import main
elif args.task == 'dialctrl':
from dialctrl.finetune import main
elif args.task in ['dialctrl-eval-ppl', 'dialctrl-eval-f1']:
from dialctrl.evaluate import main
elif args.task == 'knwl-dialo-prompt':
from knwl_dialo.prompt import main
elif args.task in ['knwl-dialo-finetune', 'knwl-dialo-gen']:
from knwl_dialo.finetune import main
elif args.task in ['knwl-dialo-eval-ppl', 'knwl-dialo-eval-f1']:
from knwl_dialo.evaluate import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
......
......@@ -76,6 +76,7 @@ def add_text_generate_args(parser):
help='additional special tokens')
group.add_argument('--line-by-line', action="store_true",
help='generate samples line by line')
group.add_argument('--prompt', action="store_true",
help='generate samples based on prompting')
group.add_argument('--prompt-file', type=str, default="",
......@@ -84,6 +85,10 @@ def add_text_generate_args(parser):
help='prompt type (context or keyphrase)')
group.add_argument('--num-prompt-examples', type=int, default=10,
help='number of prompt examples')
group.add_argument("--noknowledge", action='store_true', default=False,
help='Do not use knowledge in prompting')
group.add_argument('--dynamic-prompt', action='store_true', default=False,
help='using different prompts for different test samples')
return parser
......@@ -114,13 +119,7 @@ def main():
if args.num_samples == 0:
if args.sample_input_file is not None:
args.micro_batch_size = 1
if args.line_by_line:
if args.prompt:
generate_samples_prompt_input_from_file(model)
else:
generate_samples_line_by_line_input_from_file(model)
else:
generate_samples_input_from_file(model)
generate_samples_input_from_file(model)
else:
generate_samples_interactive(model)
else:
......