Commit ec202742 authored by zihanl

add recent code and prompting

parent 7b799b7c
......@@ -3,4 +3,14 @@ __pycache__
# Distribution / packaging
build/
dist/
*.egg-info/
\ No newline at end of file
*.egg-info/
tensorboard
commands/
commands_new/
*.log
logs
*.so
*.out
train_gpt_conv.py
dialogctrl/
control_gen/
\ No newline at end of file
......@@ -24,6 +24,7 @@ import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
......@@ -190,6 +191,362 @@ def generate_samples_input_from_file(model):
raw_text = None
context_count += 1
def generate_samples_line_by_line_input_from_file(model):
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
input_pos = 0
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('`sample-output-file` not specified, setting '
'it to {}'.format(sample_output_file))
else:
sample_output_file = args.sample_output_file
fname_out = open(sample_output_file, "w")
context_count = 0
model.eval()
with torch.no_grad():
while True:
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
input_pos += 1
raw_text_len = len(raw_text)
context_tokens = tokenizer.tokenize(raw_text)
else:
context_tokens = tokenizer.tokenize("EMPTY TEXT")
if input_pos % 100 == 0:
print_rank_0("input_pos: %d" % input_pos)
token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream):
pass
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
if "\r" in trim_decode_tokens:
trim_decode_tokens = trim_decode_tokens.replace("\r", "")
if "\n" in trim_decode_tokens:
trim_decode_tokens = trim_decode_tokens.replace("\n", "")
fname_out.write(trim_decode_tokens)
fname_out.write("\n")
raw_text = None
context_count += 1
if input_pos == input_count:
return
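# Note (illustrative, not part of this commit): `--sample-input-file` holds one
# prompt per line; each line is tokenized, completed by the model, and the
# continuation (with any "\r"/"\n" removed) is written as one line of the
# output file, so output line i aligns with input line i.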
def generate_samples_prompt_input_from_file(model):
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
input_pos = 0
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('`sample-output-file` not specified, setting '
'it to {}'.format(sample_output_file))
else:
sample_output_file = args.sample_output_file
fname_out = open(sample_output_file, "w")
# Read the prompt file
with open(args.prompt_file, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:args.num_prompt_examples]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
assert args.prompt_type in ["context", "keyphrase"]
context_count = 0
model.eval()
with torch.no_grad():
while True:
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
input_str = all_raw_text[input_pos]
input_str = input_str.strip()
splits = input_str.split("\t")
control_codes = splits[0].split(" [CTRL] ")
topic = control_codes[0]
raw_text = prompt
if args.prompt_type == "context":
turns = splits[1].split(" [SEP] ")
context = turns[-1]
raw_text += "( " + context + " ) " + topic + " :"
else:
keyphrase_list = control_codes[1:]
for i, keyphrase in enumerate(keyphrase_list):
if i == 0:
raw_text += "( "
else:
raw_text += "; "
raw_text += keyphrase
if len(keyphrase_list) > 0:
raw_text += " ) "
raw_text += topic + " :"
input_pos += 1
raw_text_len = len(raw_text)
context_tokens = tokenizer.tokenize(raw_text)
else:
context_tokens = tokenizer.tokenize("EMPTY TEXT")
if input_pos % 100 == 0:
print_rank_0("input_pos: %d" % input_pos)
token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream):
pass
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
generated_output = trim_decode_tokens.split("\n")[0]
generated_output = generated_output.strip()
fname_out.write(generated_output)
fname_out.write("\n")
raw_text = None
context_count += 1
if input_pos == input_count:
return
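# Worked example (illustrative, following the parsing above): for an input line
#   "science [CTRL] gravity [CTRL] mass\tHi! [SEP] Tell me about physics."
# --prompt-type context builds
#   raw_text = prompt + "( Tell me about physics. ) science :"
# while --prompt-type keyphrase builds
#   raw_text = prompt + "( gravity; mass ) science :"
# and only the model's continuation up to the first newline is written out.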
def dialog_with_gpt_control_interactive(conv_model, ctrl_model, add_separator):
args = get_args()
tokenizer = get_tokenizer()
conv_model.eval()
ctrl_model.eval()
dialog_history = []
with torch.no_grad():
while True:
            ctrl_input_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
                # the user enters "<control code> @@ <current turn>"
input_text = input(">>> ")
while not input_text:
print("Input should not be empty!")
input_text = input(">>> ")
assert " @@ " in input_text, "Please input with a correct template"
splits = input_text.split(" @@ ")
ctrl_code = splits[0]
curr_turn = splits[1]
prev_two_turns = ""
                if add_separator:
for i, turn in enumerate(dialog_history[-2:]):
if i == 0:
prev_two_turns = "<< " + turn + " >>"
else:
prev_two_turns += " "
prev_two_turns += "<< " + turn + " >>"
else:
prev_two_turns = " ".join(dialog_history[-2:])
dialog_history.append(curr_turn)
print("\nHistory:", prev_two_turns)
print("User:", curr_turn)
                if add_separator:
curr_turn = "<< " + curr_turn + " >>"
if prev_two_turns != "":
dialog_context = prev_two_turns + " " + curr_turn
else:
dialog_context = curr_turn
ctrl_input = ctrl_code + " " + dialog_context
                if add_separator:
ctrl_input += " :"
ctrl_input_text_len = len(ctrl_input)
ctrl_context_tokens = tokenizer.tokenize(ctrl_input)
else:
ctrl_context_tokens = tokenizer.tokenize("EMPTY TEXT")
token_stream = get_token_stream(ctrl_model, [ctrl_context_tokens])
for _, decode_tokens in enumerate(token_stream):
pass
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
control_sent = tokenizer.detokenize(
decode_tokens)[ctrl_input_text_len:]
control_sent = control_sent.replace("<|endoftext|>", "")
print("\nControl Sentence:", control_sent)
if control_sent != "":
control_sent = "( " + control_sent + " )"
conv_input = control_sent + " " + dialog_context
else:
conv_input = dialog_context
conv_input_text_len = len(conv_input)
conv_context_tokens = tokenizer.tokenize(conv_input)
token_stream = get_token_stream(conv_model, [conv_context_tokens])
for _, decode_tokens in enumerate(token_stream):
pass
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
response = tokenizer.detokenize(
decode_tokens)[conv_input_text_len:]
response = response.replace("<|endoftext|>", "")
print("\nChatbot:", response)
dialog_history.append(response)
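# Example input template (illustrative): the user types
#   >>> science @@ what do you know about black holes?
# i.e. "<control code> @@ <current turn>". The control model completes
# "<control code> <dialog context>" (suffixed with " :" when --add-separator is
# set) into a control sentence, which is then prepended as
# "( <control sentence> )" to the dialog context for the conversational model.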
def dialog_with_dpr_control_interactive(conv_model, ctrl_model, ctrl_tokenizer,
                                        knowledge_corpus, knowledge_corpus_emb, add_separator):
args = get_args()
tokenizer = get_tokenizer()
conv_model.eval()
ctrl_model.eval()
dialog_history = []
with torch.no_grad():
while True:
input_text = input(">>> ")
while not input_text:
print("Input should not be empty!")
input_text = input(">>> ")
assert " @@ " in input_text, "Please input with a correct template"
splits = input_text.split(" @@ ")
ctrl_code = splits[0]
curr_turn = splits[1]
prev_two_turns = " ".join(dialog_history[-2:])
prev_two_turns_v2 = ""
            if add_separator:
for i, turn in enumerate(dialog_history[-2:]):
if i == 0:
prev_two_turns_v2 = "<< " + turn + " >>"
else:
prev_two_turns_v2 += " "
prev_two_turns_v2 += "<< " + turn + " >>"
else:
prev_two_turns_v2 = prev_two_turns
dialog_history.append(curr_turn)
print("\nHistory:", prev_two_turns_v2)
print("\nUser:", curr_turn)
if prev_two_turns != "":
dialog_context = prev_two_turns + " " + curr_turn
else:
dialog_context = curr_turn
            if add_separator:
                curr_turn = "<< " + curr_turn + " >>"
                if prev_two_turns_v2 != "":
                    dialog_context_v2 = prev_two_turns_v2 + " " + curr_turn
                else:
                    dialog_context_v2 = curr_turn
            else:
                dialog_context_v2 = dialog_context
ctrl_input = ctrl_code + " " + dialog_context
ctrl_input_ids = ctrl_tokenizer.encode(ctrl_input)
ctrl_input_ids = torch.LongTensor([ctrl_input_ids]).cuda()
attn_masks = torch.ones(1, ctrl_input_ids.size()[-1]).cuda()
query_emb = ctrl_model(input_ids=ctrl_input_ids,
attention_mask=attn_masks).pooler_output # (1,768)
logits = knowledge_corpus_emb.matmul(query_emb[0])
retrieved_idx = torch.argmax(logits).item()
control_sent = knowledge_corpus[retrieved_idx].strip()
print("\nControl Sentence:", control_sent)
if control_sent != "":
control_sent = "( " + control_sent + " )"
conv_input = control_sent + " " + dialog_context_v2
else:
conv_input = dialog_context_v2
conv_input_text_len = len(conv_input)
conv_context_tokens = tokenizer.tokenize(conv_input)
token_stream = get_token_stream(conv_model, [conv_context_tokens])
for _, decode_tokens in enumerate(token_stream):
pass
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
response = tokenizer.detokenize(
decode_tokens)[conv_input_text_len:]
response = response.replace("<|endoftext|>", "")
print("\nChatbot:", response)
dialog_history.append(response)
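def _build_knowledge_corpus_emb(knowledge_corpus, output_path):
    # A minimal sketch (assumption, not part of this commit) of how the
    # precomputed `knowledge_corpus_emb` consumed above could be built, so that
    # row i of the saved tensor corresponds to line i of `knowledge_corpus`.
    # Retrieval above is then a dot product against the DPR question embedding.
    from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
    ctx_encoder = DPRContextEncoder.from_pretrained(
        'facebook/dpr-ctx_encoder-single-nq-base').cuda()
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
        'facebook/dpr-ctx_encoder-single-nq-base')
    embs = []
    with torch.no_grad():
        for passage in knowledge_corpus:
            inputs = ctx_tokenizer(passage.strip(), return_tensors='pt',
                                   truncation=True, max_length=512)
            inputs = {k: v.cuda() for k, v in inputs.items()}
            embs.append(ctx_encoder(**inputs).pooler_output)  # (1, 768)
    torch.save(torch.cat(embs, dim=0), output_path)  # (N, 768)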
# We added this function to support task evaluations such as SQuAD
# and DROP in the https://github.com/EleutherAI/lm-evaluation-harness
# codebase. The lm-evaluation-harness code can now call this function
......
......@@ -32,20 +32,15 @@ def read_data(tokenizer, data_path, train_module):
turns = turns[-3:]
# input_ids
input_ids = []
if length_split > 2:
input_ids.extend(tokenizer.tokenize("( " + ctrl_sent + " )"))
for idx, turn in enumerate(turns):
if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
turn = turn + " ."
if idx == 0:
input_ids = tokenizer.tokenize(turn)
else:
# input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
input_ids.extend(tokenizer.tokenize(turn))
input_ids.extend(tokenizer.tokenize(turn))
if length_split > 2:
            # when there is a control sentence, append it to input_ids
# input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
input_ids.extend(tokenizer.tokenize("( " + ctrl_sent + " ) ."))
# output_ids
output_ids = tokenizer.tokenize(response)
......@@ -59,23 +54,6 @@ def read_data(tokenizer, data_path, train_module):
ctrl_code = splits[1] if length_split == 4 else None
turns = dialog_context.split(" [SEP] ")
# last_turn = turns[-1]
# turns = turns[-3:]
# for idx, turn in enumerate(turns):
# if idx == 0:
# input_ids = tokenizer.tokenize(turn)
# else:
# # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
# input_ids.extend(tokenizer.tokenize(turn))
# # input_ids
# if ctrl_code:
# ctrl_code_list = ctrl_code.split(" [CTRL] ")
# for code in ctrl_code_list:
# # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
# input_ids.extend(tokenizer.tokenize(code + " ."))
        # put the control code at the beginning
input_ids = []
if ctrl_code:
......@@ -96,11 +74,99 @@ def read_data(tokenizer, data_path, train_module):
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
else:
raise ValueError("Please input a correct train-module name! (either dialog or cnotrol))")
raise ValueError("Please input a correct train-module name! " \
"(either dialog or cnotrol))")
return data_list
def read_data_v2(tokenizer, data_path, train_module,
last_turn=False, no_control_code=False, add_separator=False,
add_ctrl_code_to_dialog=False, remove_ctrl_sent=False):
"""
Read and tokenize data for version 2 (v2) data files.
Format: control code \t dialog context \t control sentence \t response.
Response only comes from the wizard.
Currently, this function is used to build test dataset for calculating PPL.
"""
data_list = []
with open(data_path, "r") as f:
for i, line in enumerate(f):
line = line.rstrip()
splits = line.split("\t")
assert len(splits) == 4
control_code = splits[0]
dialog_context = splits[1]
control_sent = splits[2]
response = splits[3]
turns = dialog_context.split(" [SEP] ")
turns = turns[-3:]
if train_module == "dialog":
# input_ids
if add_ctrl_code_to_dialog:
ctrl_code = control_code.split(" [CTRL] ")[0]
input_ids = tokenizer.tokenize("( " + ctrl_code + " )")
if not remove_ctrl_sent and control_sent != "no_passages_used":
input_ids.extend(tokenizer.tokenize("( " + control_sent + " )")[:256])
else:
if remove_ctrl_sent or control_sent == "no_passages_used":
input_ids = []
else:
input_ids = tokenizer.tokenize("( " + control_sent + " )")[:256]
for turn in turns:
if add_separator:
turn = "<< " + turn + " >>"
input_ids.extend(tokenizer.tokenize(turn))
if add_separator:
input_ids.extend(tokenizer.tokenize(":"))
# output_ids
output_ids = tokenizer.tokenize(response)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
elif train_module == "control":
# skip example without control sentences
if control_sent == "no_passages_used":
continue
input_ids = []
if not no_control_code:
ctrl_code_list = control_code.split(" [CTRL] ")[:3]
                    # keep at most three control codes
for code in ctrl_code_list:
if len(code) > 0:
input_ids.extend(tokenizer.tokenize("( " + code + " )"))
if last_turn:
input_ids.extend(tokenizer.tokenize(turns[-1]))
else:
for turn in turns:
if add_separator:
turn = "<< " + turn + " >>"
input_ids.extend(tokenizer.tokenize(turn))
if add_separator:
input_ids.extend(tokenizer.tokenize(":"))
output_ids = tokenizer.tokenize(control_sent)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
else:
raise ValueError("Please input a correct train-module name! " \
"(either dialog or cnotrol))")
return data_list
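# Worked example (illustrative) for train_module == "dialog" with
# add_separator=True and defaults elsewhere, given the tab-separated line
#   science [CTRL] gravity <TAB> Hi! [SEP] Tell me more. <TAB> Gravity bends light. <TAB> Sure, it does.
# the function builds
#   input_ids  = tokenize("( Gravity bends light. )")[:256]
#              + tokenize("<< Hi! >>") + tokenize("<< Tell me more. >>")
#              + tokenize(":")
#   output_ids = tokenize("Sure, it does.")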
def data_shuffle(data, seed):
# set random seed to make the shuffling reproducible
np.random.seed(seed)
......@@ -125,7 +191,7 @@ class ControlDialogDataset(torch.utils.data.Dataset):
data_dict = self.data[idx]
input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
# assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
# length_of_loss_mask == length_of_text - 1
# text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
......@@ -140,29 +206,62 @@ class ControlDialogDataset(torch.utils.data.Dataset):
text += [self.pad_id] * (self.max_seq_len+1 - text_len)
loss_mask += [0] * (self.max_seq_len+1 - text_len)
return {"text": np.array(text, dtype=np.int64), "loss_mask": np.array(loss_mask, dtype=np.int64)}
return {"text": np.array(text, dtype=np.int64), \
"loss_mask": np.array(loss_mask, dtype=np.int64)}
def build_train_valid_test_datasets(data_folder, dataset_name, train_module, max_seq_len, seed):
def build_train_valid_datasets(train_data_path, valid_data_path, train_module,
max_seq_len, seed, last_turn, no_control_code,
add_separator, add_ctrl_code_to_dialog, remove_ctrl_sent):
"""Build train, valid, and test datasets."""
dataname_dict = {"wizard_of_wikipedia": {"train": "train_entity_based_control.txt", "valid": "valid_random_split_entity_based_control.txt", "test": "test_random_split_entity_based_control.txt"}}
# dataname_dict = {"wizard_of_wikipedia": {"train": "train_entity_based_control.txt", "valid": "valid_random_split_entity_based_control.txt", "test": "test_random_split_entity_based_control.txt"}}
train_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["train"])
valid_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["valid"])
test_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["test"])
# train_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["train"])
# valid_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["valid"])
# test_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["test"])
tokenizer = get_tokenizer()
train_data_list = read_data(tokenizer, train_data_path, train_module)
valid_data_list = read_data(tokenizer, valid_data_path, train_module)
test_data_list = read_data(tokenizer, test_data_path, train_module)
# train_data_list = read_data(tokenizer, train_data_path, train_module)
train_data_list = read_data_v2(tokenizer, train_data_path, train_module,
last_turn, no_control_code, add_separator,
add_ctrl_code_to_dialog, remove_ctrl_sent)
valid_data_list = read_data_v2(tokenizer, valid_data_path, train_module,
last_turn, no_control_code, add_separator,
add_ctrl_code_to_dialog, remove_ctrl_sent)
# shuffle the training data
train_data_list = data_shuffle(train_data_list, seed)
# build train, valid, and test datasets
train_dataset = ControlDialogDataset(train_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
test_dataset = ControlDialogDataset(test_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
# build train, valid datasets
train_dataset = ControlDialogDataset(train_data_list,
max_seq_len,
sep_id=tokenizer.sep_id,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
valid_dataset = ControlDialogDataset(valid_data_list,
max_seq_len,
sep_id=tokenizer.sep_id,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
return train_dataset, valid_dataset
def build_test_dataset(test_data_path, train_module, max_seq_len,
last_turn, no_control_code, add_separator,
add_ctrl_code_to_dialog, remove_ctrl_sent):
tokenizer = get_tokenizer()
test_data_list = read_data_v2(tokenizer, test_data_path, train_module,
last_turn, no_control_code, add_separator,
add_ctrl_code_to_dialog, remove_ctrl_sent)
test_dataset = ControlDialogDataset(test_data_list,
max_seq_len,
sep_id=tokenizer.sep_id,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
return train_dataset, valid_dataset, test_dataset
return test_dataset
from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.checkpointing import load_checkpoint
from tasks.finetune_utils import build_data_loader
from tasks.dialctrl.data import build_test_dataset
from tasks.dialctrl.finetune import model_provider, process_batch, loss_func, forward_step
from tasks.dialctrl.metrics import F1Metric
from tqdm import tqdm
def test_dataset_provider():
"""Build the test dataset for dialog/control module"""
args = get_args()
print_rank_0('> building the test dataset for %s module ...' \
% args.train_module)
test_ds = build_test_dataset(
test_data_path=args.test_data_path,
train_module=args.train_module,
max_seq_len=args.max_seq_len,
last_turn=args.last_turn,
no_control_code=args.no_control_code,
add_separator=args.add_separator,
add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
remove_ctrl_sent=args.remove_ctrl_sent)
print_rank_0("> finished creating the test dataset for %s module ..." \
% args.train_module)
print_rank_0('> test set size: %d' % len(test_ds))
args.eval_iters = len(test_ds) // args.global_batch_size
print_rank_0('> evaluation iteration: %d' % args.eval_iters)
return test_ds
def _build_test_iterator(test_dataset, task_collate_fn=None):
"""Test dataloader."""
args = get_args()
print_rank_0('building test dataloader ...')
# Test loader
test_dataloader = build_data_loader(test_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last,
task_collate_fn)
    test_iterator = iter(test_dataloader)
return test_iterator
def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
args = get_args()
timers = get_timers()
# test dataloader.
    timers('test dataset/dataloader').start()
test_dataset = test_dataset_provider()
test_iterator = _build_test_iterator(test_dataset)
    timers('test dataset/dataloader').stop()
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
timers('model and optimizer').stop()
timers('pretrained checkpoint').start()
    iteration = 0
    if args.pretrained_checkpoint is not None:
original_load = args.load
args.load = args.pretrained_checkpoint
original_rng = args.no_load_rng
args.no_load_rng = True
iteration = load_checkpoint(model, None, None)
args.load = original_load
args.no_load_rng = original_rng
        # This is critical when only the model is loaded. We should make sure
        # the main parameters are also updated.
optimizer.reload_model_params()
timers('pretrained checkpoint').stop()
# Print setup timing.
print_rank_0('done with setups ...')
    timers.log(['test dataset/dataloader', 'model and optimizer',
                'pretrained checkpoint'])
print_rank_0('evaluating ...')
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step,
test_iterator, model,
iteration, False)
print_rank_0('done :-)')
def evaluate_f1(guess_file, answer_file, remove_stopwords):
guess_list = []
print_rank_0('reading %s' % guess_file)
with open(guess_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if "<|endoftext|>" in line:
line = line.replace("<|endoftext|>", "")
guess_list.append(line)
answer_list = []
print_rank_0('reading %s' % answer_file)
with open(answer_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if line == "no_passages_used":
line = ""
answer_list.append(line)
assert len(guess_list) == len(answer_list), \
"lengths of guess and answer are different!"
precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list, remove_stopwords)
print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))
print_rank_0('done :-)')
def main():
args = get_args()
if 'ppl' in args.task:
evaluate_ppl(test_dataset_provider, model_provider, forward_step)
elif 'f1' in args.task:
evaluate_f1(args.guess_file, args.answer_file, args.remove_stopwords)
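# Illustrative invocations (assumption; the exact launcher depends on the
# setup, but all task flags below are defined in this repo's task arguments):
#   --task dialctrl-eval-ppl --train-module dialog --test-data-path test.txt ...
#   --task dialctrl-eval-f1 --guess-file gens.txt --answer-file refs.txt --remove-stopwords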
......@@ -12,7 +12,7 @@ from megatron.model import GPTModel
from megatron.training import evaluate_and_print_results
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import finetune
from tasks.dialctrl.data import build_train_valid_test_datasets
from tasks.dialctrl.data import build_train_valid_datasets
from tasks.dialctrl.utils import get_ltor_attention_masks_and_position_ids
......@@ -35,16 +35,27 @@ def train_valid_datasets_provider():
    print_rank_0('> building train and validation datasets for %s module ...' % args.train_module)
train_ds, valid_ds, _ = build_train_valid_test_datasets(
data_folder=args.data_folder,
dataset_name=args.dataset_name,
train_ds, valid_ds = build_train_valid_datasets(
train_data_path=args.train_data_path,
valid_data_path=args.test_data_path,
train_module=args.train_module,
max_seq_len=args.max_seq_len,
seed=args.seed)
seed=args.seed,
last_turn=args.last_turn,
no_control_code=args.no_control_code,
add_separator=args.add_separator,
add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
remove_ctrl_sent=args.remove_ctrl_sent)
print_rank_0("> finished creating datasets for %s module ..." % args.train_module)
print_rank_0('> Train size: %d' % len(train_ds))
print_rank_0('> Validation size: %d' % len(valid_ds))
args.eval_interval = len(train_ds) // args.global_batch_size
print_rank_0(' > evaluation interval: %d' % args.eval_interval)
print_rank_0('> evaluation interval: %d' % args.eval_interval)
args.eval_iters = len(valid_ds) // args.global_batch_size
print_rank_0('> evaluation iteration: %d' % args.eval_iters)
return train_ds, valid_ds
......
# The following code is adapted from
# https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py,
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE.
"""Provides standard metric evaluations for dialog."""
from collections import Counter
from typing import List
import numpy as np
import re
from nltk.corpus import stopwords
re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
stopword_list = stopwords.words('english')
stopword_list = stopword_list + ["n's", "'s"]
stopword_dict = {token: True for token in stopword_list}
def normalize_answer(s):
"""
Lower text and remove punctuation, articles and extra whitespace.
"""
s = s.lower()
s = re_punc.sub(' ', s)
s = re_art.sub(' ', s)
# TODO: this could almost certainly be faster with a regex \s+ -> ' '
s = ' '.join(s.split())
return s
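# e.g. normalize_answer("The Eiffel-Tower!") -> "eiffel tower"
# ("the" is dropped as an article; the hyphen and "!" become spaces)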
def remove_stopwords(token_list):
new_list = []
for token in token_list:
if token in stopword_dict:
continue
new_list.append(token)
return new_list
class F1Metric:
"""
Helper class which computes token-level F1.
"""
@staticmethod
def _prec_recall_f1_score(pred_items, gold_items):
"""
Compute precision, recall and f1 given a set of gold and prediction items.
:param pred_items: iterable of predicted values
:param gold_items: iterable of gold values
:return: tuple (p, r, f1) for precision, recall, f1
"""
common = Counter(gold_items) & Counter(pred_items)
num_same = sum(common.values())
if num_same == 0:
return 0, 0, 0
precision = 1.0 * num_same / len(pred_items)
recall = 1.0 * num_same / len(gold_items)
f1 = (2 * precision * recall) / (precision + recall)
return precision, recall, f1
@staticmethod
def compute_each_pair(guess: str, answer: str, rm_sw: bool):
if answer == "":
return None, None, None
if guess == "":
return 0, 0, 0
g_tokens = normalize_answer(guess).split()
a_tokens = normalize_answer(answer).split()
if rm_sw:
g_tokens = remove_stopwords(g_tokens)
a_tokens = remove_stopwords(a_tokens)
if len(a_tokens) == 0:
return None, None, None
if len(g_tokens) == 0:
return 0, 0, 0
precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
return precision, recall, f1
@staticmethod
def compute_all_pairs(guesses: List[str], answers: List[str], rm_sw=False):
        # additional argument:
        # rm_sw: whether to remove stopwords
assert len(guesses) == len(answers)
precision_list, recall_list, f1_list = [], [], []
for guess, answer in zip(guesses, answers):
precision, recall, f1 = F1Metric.compute_each_pair(guess, answer, rm_sw)
if precision is None or recall is None or f1 is None:
continue
precision_list.append(precision)
recall_list.append(recall)
f1_list.append(f1)
return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list)
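# Worked example (illustrative):
#   F1Metric.compute_each_pair("the sky is blue", "sky was blue", rm_sw=False)
# normalizes both sides ("the" is stripped as an article), giving tokens
# [sky, is, blue] vs [sky, was, blue]; the overlap is 2 tokens, so
# precision = recall = f1 = 2/3. With rm_sw=True, "is"/"was" are dropped as
# stopwords and f1 becomes 1.0.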
......@@ -8,7 +8,9 @@ def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
micro_batch_size, seq_length = data.size()
# Attention mask
attention_mask = torch.tril(torch.ones((micro_batch_size, seq_length, seq_length), device=data.device)).view(micro_batch_size, 1, seq_length, seq_length)
attention_mask = torch.tril(torch.ones(
(micro_batch_size, seq_length, seq_length), device=data.device)).view(
micro_batch_size, 1, seq_length, seq_length)
# mask padded tokens
for b in range(micro_batch_size):
......
......@@ -87,15 +87,41 @@ def get_tasks_args(parser):
# finetune for controllable dialogue
group.add_argument('--train-module', type=str, default="",
help='either control module or dialogue model (control or dialog)')
group.add_argument('--data-folder', type=str, default="",
help='data folder (path of the data folder)')
group.add_argument('--dataset-name', type=str, default="",
help='dataset name (e.g., wizard_of_wikipedia)')
group.add_argument('--train-data-path', type=str, default="",
help='datapath for training set')
group.add_argument('--test-data-path', type=str, default="",
help='datapath for test set')
group.add_argument('--guess-file', type=str, default="",
help='datapath for generated sentences')
group.add_argument('--answer-file', type=str, default="",
help='datapath for golden sentences')
group.add_argument('--max-seq-len', type=int, default=1024,
help='maximum sequence length')
group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
group.add_argument('--spec-toks', type=str, default=None,
help='additional special tokens')
    group.add_argument('--last-turn', action='store_true',
                       help='use only the last turn for the control model')
    group.add_argument('--no-control-code', action='store_true',
                       help='remove the control code when training the control model')
    group.add_argument('--remove-stopwords', action='store_true',
                       help='remove stopwords when evaluating the F1 score')
    group.add_argument('--add-separator', action='store_true',
                       help='add separators between turns and a colon before generation')
    group.add_argument('--add-ctrl-code-to-dialog', action='store_true',
                       help='add the control code to the dialog model input')
    group.add_argument('--remove-ctrl-sent', action='store_true',
                       help="don't use the control sentence in dialog modeling")
# finetune for controllable generation
group.add_argument('--wiki-path', type=str, default="",
help='data path for the wikipedia corpus')
group.add_argument('--tokenized-path', type=str, default="",
help='data path for the tokenized file')
group.add_argument('--prop', type=float, default=1.0,
help='Proportion of data used for training')
group.add_argument('--max-instance', type=int, default=10000000,
                       help='maximum number of training instances')
return parser
......@@ -120,8 +146,12 @@ if __name__ == '__main__':
from orqa.evaluate_orqa import main
elif args.task in ['RET-FINETUNE-NQ']:
from orqa.supervised.finetune import main
elif args.task == 'control-gen':
from control_gen.finetune import main
elif args.task == 'dialctrl':
from dialctrl.finetune import main
elif args.task in ['dialctrl-eval-ppl', 'dialctrl-eval-f1']:
from dialctrl.evaluate import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
......
"""Sample Generate Controllable Dialog Model"""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import argparse
import torch
from transformers import DPRQuestionEncoderTokenizer
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import dialog_with_gpt_control_interactive, dialog_with_dpr_control_interactive
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(num_tokentypes=0, parallel_output=False,
pre_process=pre_process, post_process=post_process)
return model
def add_control_dialog_generate_args(parser):
"""Text generation arguments."""
group = parser.add_argument_group(title='text generation')
group.add_argument("--temperature", type=float, default=1.0,
help='Sampling temperature.')
group.add_argument("--greedy", action='store_true', default=False,
help='Use greedy sampling.')
group.add_argument("--top_p", type=float, default=0.0,
help='Top p sampling.')
group.add_argument("--top_k", type=int, default=0,
help='Top k sampling.')
group.add_argument("--out-seq-length", type=int, default=1024,
help='Size of the output generated text.')
group.add_argument("--recompute", action='store_true',
help='During generation recompute all attention '
'instead of using previously computed keys/values.')
group.add_argument("--ctrl-type", type=str, default="",
help="Either dpr or gpt")
group.add_argument("--ctrl-hidden-size", type=int, default=1024,
help="hidden-size of gpt control model")
group.add_argument("--ctrl-num-layers", type=int, default=24,
help="num-layers of gpt control model")
group.add_argument("--ctrl-num-attention-heads", type=int, default=16,
help="num-attention-heads of gpt control model")
group.add_argument("--ctrl-gpt-load", type=str, default="",
help="checkpoint path of the gpt control model")
group.add_argument("--ctrl-dpr-load", type=str, default="",
help="checkpoint path of the dpr control model")
group.add_argument("--knowledge-corpus-path", type=str, default="",
help="The path for the knowledge corpus")
group.add_argument("--knowledge-corpus-emb", type=str, default="",
help="The path for the knowledge embedding")
group.add_argument('--spec-toks', type=str, default=None,
help='additional special tokens')
group.add_argument('--add-separator', action="store_true",
help='Add separator for the inputs')
return parser
def main():
"""Main program."""
initialize_megatron(extra_args_provider=add_control_dialog_generate_args,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
'no_load_rng': True,
'no_load_optim': True})
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
# Set up conversational model
conv_model = get_model(model_provider)
if args.load is not None:
_ = load_checkpoint(conv_model, None, None)
assert len(conv_model) == 1, "Above condition should have caught this"
conv_model = conv_model[0]
# Set up control model
assert args.ctrl_type in ["gpt", "dpr"], \
"please input a correct control model type"
if args.ctrl_type == "gpt":
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
args.hidden_size = args.ctrl_hidden_size
args.ffn_hidden_size = 4 * args.hidden_size
args.num_layers = args.ctrl_num_layers
args.num_attention_heads = args.ctrl_num_attention_heads
args.load = args.ctrl_gpt_load
ctrl_model = get_model(model_provider)
if args.load is not None:
_ = load_checkpoint(ctrl_model, None, None)
ctrl_model = ctrl_model[0]
dialog_with_gpt_control_interactive(conv_model, ctrl_model, args.add_separator)
else:
print_rank_0("> Loading model from %s" % args.ctrl_dpr_load)
ctrl_model = torch.load(args.ctrl_dpr_load)
ctrl_model.cuda()
ctrl_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
print_rank_0("> Loading knowledge corpus and embeddings")
with open(args.knowledge_corpus_path, "r") as f:
knowledge_corpus = f.readlines()
knowledge_corpus_emb = torch.load(args.knowledge_corpus_emb)
knowledge_corpus_emb = knowledge_corpus_emb.cuda()
assert knowledge_corpus_emb.size()[0] == len(knowledge_corpus), \
"The size of knowledge corpus and embeddings should be the same"
dialog_with_dpr_control_interactive(conv_model, ctrl_model,
ctrl_tokenizer, knowledge_corpus,
knowledge_corpus_emb, args.add_separator)
if __name__ == "__main__":
main()
......@@ -30,6 +30,8 @@ from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import generate_and_write_samples_unconditional
from megatron.text_generation_utils import generate_samples_input_from_file
from megatron.text_generation_utils import generate_samples_prompt_input_from_file
from megatron.text_generation_utils import generate_samples_line_by_line_input_from_file
from megatron.text_generation_utils import generate_samples_interactive
......@@ -70,6 +72,18 @@ def add_text_generate_args(parser):
group.add_argument("--recompute", action='store_true',
help='During generation recompute all attention '
'instead of using previously computed keys/values.')
group.add_argument('--spec-toks', type=str, default=None,
help='additional special tokens')
group.add_argument('--line-by-line', action="store_true",
help='generate samples line by line')
group.add_argument('--prompt', action="store_true",
help='generate samples based on prompting')
group.add_argument('--prompt-file', type=str, default="",
help='prompting file')
group.add_argument('--prompt-type', type=str, default="",
help='prompt type (context or keyphrase)')
group.add_argument('--num-prompt-examples', type=int, default=10,
help='number of prompt examples')
return parser
......@@ -98,9 +112,15 @@ def main():
# Generate samples.
if args.num_samples == 0:
args.micro_batch_size = 1
        if args.sample_input_file is not None:
generate_samples_input_from_file(model)
args.micro_batch_size = 1
if args.line_by_line:
if args.prompt:
generate_samples_prompt_input_from_file(model)
else:
generate_samples_line_by_line_input_from_file(model)
else:
generate_samples_input_from_file(model)
else:
generate_samples_interactive(model)
else:
......