Commit 3c363d57 authored by zihanl's avatar zihanl

update dialog dataset and training

parent d7509658
......@@ -20,11 +20,16 @@ def read_data(tokenizer, data_path, train_module):
assert length_split == 2 or length_split == 3 or length_split == 4
if train_module == "dialog":
# if length_split == 2:
# continue
dialog_context = splits[0]
if length_split > 2:
ctrl_sent = splits[-2]
response = splits[-1]
# only take the last three turns in the dialog context
turns = dialog_context.split(" [SEP] ")
turns = turns[-3:]
# turns = turns[-3:]
# input_ids
for idx, turn in enumerate(turns):
......@@ -33,6 +38,10 @@ def read_data(tokenizer, data_path, train_module):
else:
input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
if length_split > 2:
# when there is a control sentence, append it to input_ids
input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
# output_ids
output_ids = tokenizer.tokenize(response)
......@@ -65,7 +74,7 @@ def read_data(tokenizer, data_path, train_module):
else:
raise ValueError("Please input a correct train-module name! (either dialog or cnotrol))")
return data_list
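A minimal sketch of the line format this hunk appears to expect for the dialog module. The tab delimiter and the field layout are assumptions (the split call sits outside the hunk); turns are joined by " [SEP] ", an optional control sentence is the second-to-last field, and the response is the last field.
# Hypothetical illustration only; delimiter and field order are assumed.
def parse_dialog_line(line):
    splits = line.strip().split("\t")                 # assumed delimiter
    assert len(splits) in (2, 3, 4)
    dialog_context = splits[0]                        # turns joined by " [SEP] "
    ctrl_sent = splits[-2] if len(splits) > 2 else None
    response = splits[-1]
    turns = dialog_context.split(" [SEP] ")
    return turns, ctrl_sent, response

turns, ctrl_sent, response = parse_dialog_line(
    "hi there [SEP] how are you?\tI am fine, thanks.")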
......@@ -78,10 +87,11 @@ def data_shuffle(data, seed):
class ControlDialogDataset(torch.utils.data.Dataset):
def __init__(self, data, max_seq_len, pad_id, eod_id):
def __init__(self, data, max_seq_len, sep_id, pad_id, eod_id):
# need to deal with padding, label masking
self.data = data
self.max_seq_len = max_seq_len
self.sep_id = sep_id
self.pad_id = pad_id
self.eod_id = eod_id
......@@ -95,16 +105,16 @@ class ControlDialogDataset(torch.utils.data.Dataset):
assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
# length_of_loss_mask == length_of_text - 1
text = input_ids + [self.pad_id] + output_ids + [self.eod_id]
text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
loss_mask = [0]*len(input_ids) + [1]*(len(output_ids)+1)
text_len = len(text)
if text_len > self.max_seq_len:
text = text[:self.max_seq_len]
loss_mask = loss_mask[:self.max_seq_len-1]
if text_len > self.max_seq_len+1:
text = text[:self.max_seq_len+1]
loss_mask = loss_mask[:self.max_seq_len]
else:
text += [self.pad_id] * (self.max_seq_len - text_len)
loss_mask += [0] * (self.max_seq_len - text_len)
text += [self.pad_id] * (self.max_seq_len+1 - text_len)
loss_mask += [0] * (self.max_seq_len+1 - text_len)
return {"text": np.array(text, dtype=np.int64), "loss_mask": np.array(loss_mask, dtype=np.int64)}
......@@ -127,8 +137,8 @@ def build_train_valid_test_datasets(data_folder, dataset_name, train_module, max
train_data_list = data_shuffle(train_data_list, seed)
# build train, valid, and test datasets
train_dataset = ControlDialogDataset(train_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
test_dataset = ControlDialogDataset(test_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
train_dataset = ControlDialogDataset(train_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
test_dataset = ControlDialogDataset(test_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
return train_dataset, valid_dataset, test_dataset
import torch
from megatron import print_rank_0
def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
"""Build attention masks and position id for left to right model."""
......@@ -10,12 +10,19 @@ def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
# Attention mask
attention_mask = torch.tril(torch.ones((micro_batch_size, seq_length, seq_length), device=data.device)).view(micro_batch_size, 1, seq_length, seq_length)
# mask padded tokens
for b in range(micro_batch_size):
for idx in range(seq_length-1):
if data[b, idx] == eod_token_id:
# mask out attention for the pad positions that come after the eod token
attention_mask[b, 0, idx+1:, :] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# reset attention mask and position ids
# Loop through the batches:
# # reset attention mask and position ids
# # Loop through the batches:
# for b in range(micro_batch_size):
# # Find indices where the EOD token is.
# eod_index = position_ids[b, data[b] == eod_token_id]
......
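A minimal sketch of the eod-based masking added above, run on a hypothetical one-sample batch of length 6 where token id 3 stands in for the eod token; the ids are invented for illustration.
import torch

data = torch.tensor([[11, 12, 3, 0, 0, 0]])            # eod at position 2
eod_token_id = 3
micro_batch_size, seq_length = data.size()

attention_mask = torch.tril(
    torch.ones((micro_batch_size, seq_length, seq_length))
).view(micro_batch_size, 1, seq_length, seq_length)

for b in range(micro_batch_size):
    for idx in range(seq_length - 1):
        if data[b, idx] == eod_token_id:
            # rows after the eod position attend to nothing,
            # so the trailing pad tokens contribute no attention
            attention_mask[b, 0, idx + 1:, :] = 0.0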
......@@ -760,6 +760,8 @@ def _add_dialog_ctrl_args(parser):
group.add_argument('--run-dialog', action='store_true',
help='run dialog modeling')
group.add_argument('--num-epoch', type=int, default=30,
help='number of epochs to train the model')
group.add_argument('--train-module', type=str, default="",
help='either control module or dialogue model (control or dialog)')
group.add_argument('--data-folder', type=str, default="",
......
......@@ -344,19 +344,21 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
print_rank_0(f' checkpoint version {checkpoint_version}')
fix_query_key_value_ordering(model, checkpoint_version)
# Optimizer.
if not release and not args.finetune and not args.no_load_optim:
try:
if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer state, '
'exiting ...'.format(checkpoint_name))
sys.exit()
if not args.run_dialog:
# Original GPT pre-training setting
# Optimizer.
if not release and not args.finetune and not args.no_load_optim:
try:
if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer state, '
'exiting ...'.format(checkpoint_name))
sys.exit()
# rng states.
if not release and not args.finetune and not args.no_load_rng:
......
......@@ -138,27 +138,57 @@ def pretrain(train_valid_test_dataset_provider,
print_rank_0('training ...')
iteration = 0
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator)
print_datetime('after training is done')
if args.do_valid:
prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
if args.do_test:
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
0, True)
if not args.run_dialog:
# original pre-training for GPT
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator)
print_datetime('after training is done')
if args.do_valid:
prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
if args.do_test:
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
0, True)
else:
# training for dialog/control model
timers('interval-time').start() # start the interval timer here so it is not started multiple times
for e in range(args.num_epoch):
print_rank_0('> training on epoch %d' % (e+1))
if args.do_train and args.train_iters > 0:
iteration += train(forward_step_func,
model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator)
print_datetime('after training is done')
if args.do_valid:
prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, False)
if e >= 8 and e <= 13 and args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
if args.do_test:
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
0, True)
def update_train_iters(args):
......@@ -611,7 +641,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
# Iterations.
iteration = args.iteration
timers('interval-time').start()
if not args.run_dialog:
timers('interval-time').start()
print_datetime('before the start of training step')
report_memory_flag = True
while iteration < args.train_iters:
......@@ -813,9 +845,10 @@ def build_train_valid_test_data_iterators(
print_rank_0(' validation: {}'.format(valid_size))
print_rank_0(' test: {}'.format(test_size))
args.train_iters = train_size // args.global_batch_size
args.eval_iters = valid_size // args.global_batch_size
args.test_iters = test_size // args.global_batch_size
batch_size = args.micro_batch_size * args.data_parallel_size
args.train_iters = train_size // batch_size + 1
args.eval_iters = valid_size // batch_size + 1
args.test_iters = test_size // batch_size + 1
else:
# Number of train/valid/test samples.
......
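A worked example of the new iteration arithmetic above, with made-up sizes (10,000 training samples, micro_batch_size=4, data_parallel_size=8); only the formula mirrors the hunk, the numbers are assumptions.
# Illustration with invented sizes.
train_size = 10_000
micro_batch_size, data_parallel_size = 4, 8

batch_size = micro_batch_size * data_parallel_size      # 32 samples per step
train_iters = train_size // batch_size + 1              # 312 + 1 = 313 steps per epoch
# the +1 keeps the trailing partial batch from being dropped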
......@@ -53,7 +53,7 @@ def get_batch(data_iterator):
# Get the attention_mask and position ids.
attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)
return tokens, labels, loss_mask, attention_mask, position_ids
......