Commit 3c363d57 authored by zihanl

update dialog dataset and training

parent d7509658

@@ -20,11 +20,16 @@ def read_data(tokenizer, data_path, train_module):
             assert length_split == 2 or length_split == 3 or length_split == 4
             if train_module == "dialog":
+                # if length_split == 2:
+                # continue
                 dialog_context = splits[0]
+                if length_split > 2:
+                    ctrl_sent = splits[-2]
                 response = splits[-1]
                 # only take the last three turns in the dialog context
                 turns = dialog_context.split(" [SEP] ")
-                turns = turns[-3:]
+                # turns = turns[-3:]
                 # input_ids
                 for idx, turn in enumerate(turns):

@@ -33,6 +38,10 @@ def read_data(tokenizer, data_path, train_module):
                     else:
                         input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
+                if length_split > 2:
+                    # when there is control sentence, add it into the input_ids
+                    input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
                 # output_ids
                 output_ids = tokenizer.tokenize(response)

@@ -65,7 +74,7 @@ def read_data(tokenizer, data_path, train_module):
             else:
                 raise ValueError("Please input a correct train-module name! (either dialog or control)")
     return data_list
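
For orientation, here is a minimal standalone sketch of the dialog-side preprocessing in the hunks above. The toy tokenizer, its special ids, the tab-delimited line format, and the behaviour of the collapsed `idx == 0` branch are assumptions for illustration only; the real tokenizer and data format come from the repository.

# Sketch only: a stand-in tokenizer with made-up special ids.
class ToyTokenizer:
    sep_id, ctrl_id, pad_id, eod_id = 0, 1, 2, 3

    def tokenize(self, text):
        # Hypothetical: hash words into a small id space above the special ids.
        return [4 + (hash(w) % 1000) for w in text.split()]

def build_example(line, tokenizer):
    # Assumes tab-separated fields: dialog context, optional control sentence, response.
    splits = line.strip().split("\t")
    assert len(splits) in (2, 3, 4)
    dialog_context, response = splits[0], splits[-1]
    ctrl_sent = splits[-2] if len(splits) > 2 else None

    input_ids = []
    for idx, turn in enumerate(dialog_context.split(" [SEP] ")):
        if idx == 0:
            # Assumed: the first turn is appended without a separator.
            input_ids.extend(tokenizer.tokenize(turn))
        else:
            input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
    if ctrl_sent is not None:
        # Control sentence goes after a CTRL marker, as in the hunk above.
        input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
    output_ids = tokenizer.tokenize(response)
    return input_ids, output_ids

line = "hi there [SEP] how are you\tI like hiking\tI am great , thanks !"
print(build_example(line, ToyTokenizer()))
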

@@ -78,10 +87,11 @@ def data_shuffle(data, seed):
 class ControlDialogDataset(torch.utils.data.Dataset):
-    def __init__(self, data, max_seq_len, pad_id, eod_id):
+    def __init__(self, data, max_seq_len, sep_id, pad_id, eod_id):
         # need to deal with padding, label masking
         self.data = data
         self.max_seq_len = max_seq_len
+        self.sep_id = sep_id
         self.pad_id = pad_id
         self.eod_id = eod_id

@@ -95,16 +105,16 @@ class ControlDialogDataset(torch.utils.data.Dataset):
         assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
         # length_of_loss_mask == length_of_text - 1
-        text = input_ids + [self.pad_id] + output_ids + [self.eod_id]
+        text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
         loss_mask = [0]*len(input_ids) + [1]*(len(output_ids)+1)
         text_len = len(text)
-        if text_len > self.max_seq_len:
-            text = text[:self.max_seq_len]
-            loss_mask = loss_mask[:self.max_seq_len-1]
+        if text_len > self.max_seq_len+1:
+            text = text[:self.max_seq_len+1]
+            loss_mask = loss_mask[:self.max_seq_len]
         else:
-            text += [self.pad_id] * (self.max_seq_len - text_len)
-            loss_mask += [0] * (self.max_seq_len - text_len)
+            text += [self.pad_id] * (self.max_seq_len+1 - text_len)
+            loss_mask += [0] * (self.max_seq_len+1 - text_len)
         return {"text": np.array(text, dtype=np.int64), "loss_mask": np.array(loss_mask, dtype=np.int64)}

@@ -127,8 +137,8 @@ def build_train_valid_test_datasets(data_folder, dataset_name, train_module, max
     train_data_list = data_shuffle(train_data_list, seed)
     # build train, valid, and test datasets
-    train_dataset = ControlDialogDataset(train_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
-    valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
-    test_dataset = ControlDialogDataset(test_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
+    train_dataset = ControlDialogDataset(train_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
+    valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
+    test_dataset = ControlDialogDataset(test_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
     return train_dataset, valid_dataset, test_dataset

 import torch
+from megatron import print_rank_0

 def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
     """Build attention masks and position id for left to right model."""

@@ -10,12 +10,19 @@ def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
     # Attention mask
     attention_mask = torch.tril(torch.ones((micro_batch_size, seq_length, seq_length), device=data.device)).view(micro_batch_size, 1, seq_length, seq_length)

+    # mask padded tokens
+    for b in range(micro_batch_size):
+        for idx in range(seq_length-1):
+            if data[b, idx] == eod_token_id:
+                # pad tokens that come after the eod token
+                attention_mask[b, 0, idx+1:, :] = 0.0
+
     # Position ids.
     position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
     position_ids = position_ids.unsqueeze(0).expand_as(data)

-    # reset attentino mask and position ids
-    # Loop through the batches:
+    # # reset attentino mask and position ids
+    # # Loop through the batches:
     # for b in range(micro_batch_size):
     #     # Find indecies where EOD token is.
     #     eod_index = position_ids[b, data[b] == eod_token_id]
...
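
A standalone sketch of the padding behaviour added in the hunk above: after the causal (lower-triangular) mask is built, every row belonging to a position after an EOD token is zeroed, so padded positions attend to nothing. The wrapper below only mirrors the lines visible in this diff; the tensor shapes and toy batch are illustrative.

import torch

def ltor_masks_with_padding(data, eod_token_id):
    # Causal mask, then zero out the attention rows of positions after an EOD.
    micro_batch_size, seq_length = data.size()
    attention_mask = torch.tril(
        torch.ones((micro_batch_size, seq_length, seq_length), device=data.device)
    ).view(micro_batch_size, 1, seq_length, seq_length)
    for b in range(micro_batch_size):
        for idx in range(seq_length - 1):
            if data[b, idx] == eod_token_id:
                attention_mask[b, 0, idx + 1:, :] = 0.0
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    return attention_mask, position_ids

# Toy batch: EOD id is 3, everything after it is padding.
data = torch.tensor([[5, 6, 3, 2, 2]])
mask, pos = ltor_masks_with_padding(data, eod_token_id=3)
print(mask[0, 0])   # rows 3 and 4 are all zeros
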

@@ -760,6 +760,8 @@ def _add_dialog_ctrl_args(parser):
     group.add_argument('--run-dialog', action='store_true',
                        help='run dialog modeling')
+    group.add_argument('--num-epoch', type=int, default=30,
+                       help='number of epochs to train the model')
     group.add_argument('--train-module', type=str, default="",
                        help='either control module or dialogue model (control or dialog)')
     group.add_argument('--data-folder', type=str, default="",
...
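
For reference, the dialog/control options in isolation, as a hypothetical standalone parser. In the repository they are added to Megatron's shared argument parser; the group title and the help text for --data-folder lie outside the lines shown above and are paraphrased here.

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('dialog and control')   # hypothetical group title
group.add_argument('--run-dialog', action='store_true',
                   help='run dialog modeling')
group.add_argument('--num-epoch', type=int, default=30,
                   help='number of epochs to train the model')
group.add_argument('--train-module', type=str, default="",
                   help='either control module or dialogue model (control or dialog)')
group.add_argument('--data-folder', type=str, default="",
                   help='folder holding the dialog/control data (paraphrased)')

args = parser.parse_args(['--run-dialog', '--train-module', 'dialog', '--num-epoch', '10'])
print(args.run_dialog, args.train_module, args.num_epoch, args.data_folder)
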

@@ -344,19 +344,21 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     print_rank_0(f' checkpoint version {checkpoint_version}')
     fix_query_key_value_ordering(model, checkpoint_version)

-    # Optimizer.
-    if not release and not args.finetune and not args.no_load_optim:
-        try:
-            if optimizer is not None:
-                optimizer.load_state_dict(state_dict['optimizer'])
-            if lr_scheduler is not None:
-                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
-        except KeyError:
-            print_rank_0('Unable to load optimizer from checkpoint {}. '
-                         'Specify --no-load-optim or --finetune to prevent '
-                         'attempting to load the optimizer state, '
-                         'exiting ...'.format(checkpoint_name))
-            sys.exit()
+    if not args.run_dialog:
+        # Original pre-train GPT setting
+        # Optimizer.
+        if not release and not args.finetune and not args.no_load_optim:
+            try:
+                if optimizer is not None:
+                    optimizer.load_state_dict(state_dict['optimizer'])
+                if lr_scheduler is not None:
+                    lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+            except KeyError:
+                print_rank_0('Unable to load optimizer from checkpoint {}. '
+                             'Specify --no-load-optim or --finetune to prevent '
+                             'attempting to load the optimizer state, '
+                             'exiting ...'.format(checkpoint_name))
+                sys.exit()

     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
...

@@ -138,27 +138,57 @@ def pretrain(train_valid_test_dataset_provider,
     print_rank_0('training ...')

     iteration = 0
-    if args.do_train and args.train_iters > 0:
-        iteration = train(forward_step_func,
-                          model, optimizer, lr_scheduler,
-                          train_data_iterator, valid_data_iterator)
-    print_datetime('after training is done')
-
-    if args.do_valid:
-        prefix = 'the end of training for val data'
-        evaluate_and_print_results(prefix, forward_step_func,
-                                   valid_data_iterator, model,
-                                   iteration, False)
-
-    if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler)
-
-    if args.do_test:
-        # Run on test data.
-        prefix = 'the end of training for test data'
-        evaluate_and_print_results(prefix, forward_step_func,
-                                   test_data_iterator, model,
-                                   0, True)
+    if not args.run_dialog:
+        # original pre-training for GPT
+        if args.do_train and args.train_iters > 0:
+            iteration = train(forward_step_func,
+                              model, optimizer, lr_scheduler,
+                              train_data_iterator, valid_data_iterator)
+        print_datetime('after training is done')
+
+        if args.do_valid:
+            prefix = 'the end of training for val data'
+            evaluate_and_print_results(prefix, forward_step_func,
+                                       valid_data_iterator, model,
+                                       iteration, False)
+
+        if args.save and iteration != 0:
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+
+        if args.do_test:
+            # Run on test data.
+            prefix = 'the end of training for test data'
+            evaluate_and_print_results(prefix, forward_step_func,
+                                       test_data_iterator, model,
+                                       0, True)
+    else:
+        # training for dialog/control model
+        timers('interval-time').start() # start timers('interval-time') here to avoid it from starting multiple times
+        for e in range(args.num_epoch):
+            print_rank_0('> training on epoch %d' % (e+1))
+            if args.do_train and args.train_iters > 0:
+                iteration += train(forward_step_func,
+                                   model, optimizer, lr_scheduler,
+                                   train_data_iterator, valid_data_iterator)
+            print_datetime('after training is done')
+
+            if args.do_valid:
+                prefix = 'the end of training for val data'
+                evaluate_and_print_results(prefix, forward_step_func,
+                                           valid_data_iterator, model,
+                                           iteration, False)
+
+            if e >= 8 and e <= 13 and args.save and iteration != 0:
+                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+
+        if args.do_test:
+            # Run on test data.
+            prefix = 'the end of training for test data'
+            evaluate_and_print_results(prefix, forward_step_func,
+                                       test_data_iterator, model,
+                                       0, True)

 def update_train_iters(args):

@@ -611,7 +641,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration

-    timers('interval-time').start()
+    if not args.run_dialog:
+        timers('interval-time').start()
+
     print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:

@@ -813,9 +845,10 @@ def build_train_valid_test_data_iterators(
         print_rank_0(' validation: {}'.format(valid_size))
         print_rank_0(' test: {}'.format(test_size))

-        args.train_iters = train_size // args.global_batch_size
-        args.eval_iters = valid_size // args.global_batch_size
-        args.test_iters = test_size // args.global_batch_size
+        batch_size = args.micro_batch_size * args.data_parallel_size
+        args.train_iters = train_size // batch_size + 1
+        args.eval_iters = valid_size // batch_size + 1
+        args.test_iters = test_size // batch_size + 1
     else:
         # Number of train/valid/test samples.
...
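
A worked example of the new iteration counts above; the dataset sizes and parallel configuration are made up, only the arithmetic mirrors the diff.

# Illustrative numbers only; the formula is the one introduced in the hunk above.
micro_batch_size = 4
data_parallel_size = 8
train_size, valid_size, test_size = 100_000, 5_000, 5_000

batch_size = micro_batch_size * data_parallel_size   # 32 samples per iteration
train_iters = train_size // batch_size + 1           # 3126
eval_iters = valid_size // batch_size + 1            # 157
test_iters = test_size // batch_size + 1             # 157
print(batch_size, train_iters, eval_iters, test_iters)
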

@@ -53,7 +53,7 @@ def get_batch(data_iterator):
     # Get the attention_mask and position ids.
     attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)

     return tokens, labels, loss_mask, attention_mask, position_ids
...