"src/vscode:/vscode.git/clone" did not exist on "926daa30f977a75c01b835405c1878730388cb94"
Commit a87777bf authored by zihanl

delete finetune part

parent 5f4e63fc
"""Build Dataset for Controllable Coversational Model"""
import os
import torch
import numpy as np
from megatron import get_tokenizer
from megatron import print_rank_0
def read_data_for_finetuning(tokenizer, data_path, module):
"""
Data Format: topic \t dialog context \t knowledge \t response.
"""
data_list = []
with open(data_path, "r") as f:
for i, line in enumerate(f):
line = line.rstrip()
splits = line.split("\t")
assert len(splits) == 4
topic = splits[0].split(" [CTRL] ")[0]
dialog_context = splits[1]
knowledge = splits[2]
response = splits[3]
turns = dialog_context.split(" [SEP] ")
turns = turns[-3:]
if module == "response":
# input_ids
input_ids = tokenizer.tokenize("( " + topic + " )")
if knowledge != "no_passages_used":
input_ids.extend(tokenizer.tokenize("( " + knowledge + " )")[:256])
for turn in turns:
turn = "<< " + turn + " >>"
input_ids.extend(tokenizer.tokenize(turn))
input_ids.extend(tokenizer.tokenize(":"))
# output_ids
output_ids = tokenizer.tokenize(response)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
elif module == "knowledge":
# skip example without knowledge sentences
if knowledge == "no_passages_used":
continue
input_ids = []
input_ids.extend(tokenizer.tokenize("( " + topic + " )"))
for turn in turns:
turn = "<< " + turn + " >>"
input_ids.extend(tokenizer.tokenize(turn))
input_ids.extend(tokenizer.tokenize(":"))
output_ids = tokenizer.tokenize(knowledge)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
else:
raise ValueError("Please input a correct module name! " \
"(either dialog or cnotrol))")
return data_list
def read_data_for_prompting(tokenizer, test_data_path, prompt_file,
module, num_prompt_examples, dynamic_prompt):
# get prompts
if dynamic_prompt:
import json
prompt_examples_dict = {}
with open(prompt_file, "r") as f:
for i, line in enumerate(f):
line = line.strip()
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt_examples = prompt_examples[:num_prompt_examples]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
                    prompt_examples_dict[key] = prompt
else:
with open(prompt_file, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:num_prompt_examples]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
data_list = []
with open(test_data_path, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0].split(" [CTRL] ")[0]
turns = splits[1].split(" [SEP] ")[-3:]
last_turn = turns[-1]
ctrl_sent = splits[2]
response = splits[3]
if dynamic_prompt:
prompt = prompt_examples_dict[topic]
if module == "response":
# input seq
input_seq = prompt
input_seq += "Topic: " + topic + ". "
input_seq += "User says: " + last_turn + " "
input_seq += "We know that: " + ctrl_sent + " "
input_seq += "System replies:"
# output seq
output_seq = response
input_ids = tokenizer.tokenize(input_seq)
output_ids = tokenizer.tokenize(output_seq)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
elif module == "knowledge":
# input seq
input_seq = prompt
input_seq += "( " + last_turn + " ) " + topic + " =>"
# output seq
output_seq = ctrl_sent
input_ids = tokenizer.tokenize(input_seq)
output_ids = tokenizer.tokenize(output_seq)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
else:
raise ValueError("Please input a correct module name! " \
"(either dialog or cnotrol))")
return data_list
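# A sketch of the few-shot sequences built above, using hypothetical strings.
# For module == "response":
#   "<prompt examples> \nTopic: whales. User says: Tell me about whales. "
#   "We know that: The blue whale is the largest known animal. System replies:"
# For module == "knowledge":
#   "<prompt examples> \n( Tell me about whales. ) whales =>"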
def data_shuffle(data, seed):
# set random seed to make the shuffling reproducible
np.random.seed(seed)
np.random.shuffle(data)
return data
class KnwlDialoDataset(torch.utils.data.Dataset):
def __init__(self, data, max_seq_len, pad_id, eod_id):
# need to deal with padding, label masking
self.data = data
self.max_seq_len = max_seq_len
self.pad_id = pad_id
self.eod_id = eod_id
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data_dict = self.data[idx]
input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
text = input_ids + output_ids + [self.eod_id]
loss_mask = [0]*(len(input_ids)-1) + [1]*(len(output_ids)+1)
text_len = len(text)
if text_len > self.max_seq_len+1:
text = text[:self.max_seq_len+1]
loss_mask = loss_mask[:self.max_seq_len]
else:
text += [self.pad_id] * (self.max_seq_len+1 - text_len)
loss_mask += [0] * (self.max_seq_len+1 - text_len)
return {"text": np.array(text, dtype=np.int64), \
"loss_mask": np.array(loss_mask, dtype=np.int64)}
def build_train_valid_datasets(train_data_path, valid_data_path, module,
max_seq_len, seed):
"""Build train, valid, and test datasets."""
tokenizer = get_tokenizer()
train_data_list = read_data_for_finetuning(tokenizer, train_data_path, module)
valid_data_list = read_data_for_finetuning(tokenizer, valid_data_path, module)
# shuffle the training data
train_data_list = data_shuffle(train_data_list, seed)
# build train, valid datasets
train_dataset = KnwlDialoDataset(train_data_list,
max_seq_len,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
valid_dataset = KnwlDialoDataset(valid_data_list,
max_seq_len,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
return train_dataset, valid_dataset
def build_test_dataset(test_data_path, module, max_seq_len):
tokenizer = get_tokenizer()
test_data_list = read_data_for_finetuning(tokenizer, test_data_path, module)
test_dataset = KnwlDialoDataset(test_data_list,
max_seq_len,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
return test_dataset
def build_test_dataset_for_prompting(test_data_path, prompt_file, module, max_seq_len,
num_prompt_examples, dynamic_prompt):
tokenizer = get_tokenizer()
test_data_list = read_data_for_prompting(tokenizer, test_data_path, prompt_file, module, \
num_prompt_examples, dynamic_prompt)
test_dataset = KnwlDialoDataset(test_data_list,
max_seq_len,
pad_id=tokenizer.pad_id,
eod_id=tokenizer.eod_id)
return test_dataset
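# build_train_valid_datasets and build_test_dataset tokenize data in the
# finetuning format, while build_test_dataset_for_prompting uses the few-shot
# prompting format; all three wrap the examples in KnwlDialoDataset.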
"""Finetuning a pretrained language model for knowledge/response generation"""
import torch
from functools import partial
from megatron import mpu
from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import evaluate_and_print_results
from megatron.training import get_model
from megatron.utils import average_losses_across_data_parallel_group
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from tasks.finetune_utils import finetune
from tasks.knwl_dialo.data import build_train_valid_datasets
from tasks.knwl_dialo.utils import get_ltor_attention_masks_and_position_ids
from tasks.knwl_dialo.utils import get_token_stream
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
return model
def train_valid_datasets_provider():
"""Build train, valid, and test datasets for dialog/control module"""
args = get_args()
print_rank_0('> building train, validation, and test datasets for %s module ...' % args.module)
train_ds, valid_ds = build_train_valid_datasets(
train_data_path=args.train_data_path,
valid_data_path=args.test_data_path,
module=args.module,
max_seq_len=args.seq_length,
seed=args.seed)
print_rank_0("> finished creating datasets for %s module ..." % args.module)
print_rank_0('> Train size: %d' % len(train_ds))
print_rank_0('> Validation size: %d' % len(valid_ds))
args.eval_interval = len(train_ds) // args.global_batch_size
print_rank_0('> evaluation interval: %d' % args.eval_interval)
args.eval_iters = len(valid_ds) // args.global_batch_size
print_rank_0('> evaluation iteration: %d' % args.eval_iters)
return train_ds, valid_ds
def process_batch(batch):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text', 'loss_mask']
datatype = torch.int64
data_b = mpu.broadcast_data(keys, batch, datatype)
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
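    # Shift by one position so labels[t] is the token that follows tokens[t];
    # with text of length seq_length + 1, both have shape [batch, seq_length].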
loss_mask = data_b['loss_mask'].float()
    # Get the attention_mask and position ids.
attention_mask, position_ids = \
get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)
return tokens, labels, loss_mask, attention_mask, position_ids
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
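    # Masked token-level mean: positions with loss_mask == 0 (context and padding)
    # are zeroed out, and the sum is normalized by the number of unmasked tokens.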
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(batch, model):
"""Forward step."""
args = get_args()
timers = get_timers()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
tokens, labels, loss_mask, attention_mask, position_ids = process_batch(batch_)
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def generate_samples_input_from_file(model):
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
input_pos = 0
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('`sample-output-file` not specified, setting '
'it to {}'.format(sample_output_file))
else:
sample_output_file = args.sample_output_file
fname_out = open(sample_output_file, "w")
context_count = 0
model.eval()
# start the generation process
with torch.no_grad():
while True:
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
raw_text = all_raw_text[input_pos]
input_pos += 1
raw_text_len = len(raw_text)
context_tokens = tokenizer.tokenize(raw_text)
else:
context_tokens = tokenizer.tokenize("EMPTY TEXT")
if input_pos % 100 == 0:
print_rank_0("input_pos: %d" % input_pos)
# get the generation outputs
token_stream = get_token_stream(model, [context_tokens])
for _, decode_tokens in enumerate(token_stream):
pass
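            # draining the stream leaves decode_tokens bound to the final,
            # fully generated sequence yielded by get_token_stream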
# write the generation to the output file
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
if "\r" in trim_decode_tokens:
trim_decode_tokens = trim_decode_tokens.replace("\r", "")
if "\n" in trim_decode_tokens:
trim_decode_tokens = trim_decode_tokens.replace("\n", "")
fname_out.write(trim_decode_tokens)
fname_out.write("\n")
raw_text = None
context_count += 1
if input_pos == input_count:
return
def run_generation(model_provider):
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
# Set up model and load checkpoint.
model = get_model(model_provider)
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
# run generation
generate_samples_input_from_file(model)
def main():
args = get_args()
if "FINETUNE" in args.task:
# finetune
finetune(train_valid_datasets_provider, model_provider, \
forward_step=forward_step)
else:
# generate
run_generation(model_provider)
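# A hedged invocation sketch; the flag names below are assumptions inferred from
# the args.* attributes used above (args.task, args.module, args.train_data_path,
# args.test_data_path), not confirmed CLI options:
#   python tasks/main.py --task KNWL-DIALO-FINETUNE --module response \
#       --train-data-path train.tsv --test-data-path test.tsv
# Any --task value containing "FINETUNE" runs finetuning; otherwise generation runs.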