Commit d7509658 authored by zihanl

change folder name and add dialog training

parent 6f72a285
@@ -6,6 +6,7 @@ import torch
import numpy as np
from megatron import get_tokenizer
from megatron import print_rank_0
def read_data(tokenizer, data_path, train_module):
"""read and tokenize dialog data"""
@@ -24,10 +25,17 @@ def read_data(tokenizer, data_path, train_module):
# only take the last three turns in the dialog context
turns = dialog_context.split(" [SEP] ")
turns = turns[-3:]
context = " [SEP] ".join(turns)
input_ids = tokenizer.tokenize(context)
# input_ids
for idx, turn in enumerate(turns):
if idx == 0:
input_ids = tokenizer.tokenize(turn)
else:
input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
# output_ids
output_ids = tokenizer.tokenize(response)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
elif train_module == "control":
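For context, a minimal runnable sketch of the per-turn tokenization introduced above, using a toy tokenizer (the class, ids, and example turns are stand-ins, not Megatron's tokenizer); the point of the change is that "[SEP]" enters as the single sep_id rather than being BPE-split inside the joined string.

# Toy stand-in for the real tokenizer (names and ids are illustrative only):
# tokenize() returns sub-word ids, sep_id is the id of the added "[SEP]" token.
class ToyTokenizer:
    sep_id = 50258
    def tokenize(self, text):
        return [hash(w) % 50000 for w in text.split()]

tokenizer = ToyTokenizer()
turns = ["how are you ?", "fine , thanks .", "what about you ?"]  # last 3 turns
input_ids = []
for idx, turn in enumerate(turns):
    if idx == 0:
        input_ids = tokenizer.tokenize(turn)
    else:
        # the separator is inserted as one special-token id, never BPE-split
        input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
print(len(input_ids))  # three turns' worth of ids plus two sep_id markers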
@@ -40,14 +48,19 @@ def read_data(tokenizer, data_path, train_module):
turns = dialog_context.split(" [SEP] ")
last_turn = turns[-1]
# input_ids
if ctrl_code:
inputs = last_turn + " [CTRL] " + ctrl_code
input_ids = tokenizer.tokenize(last_turn)
ctrl_code_list = ctrl_code.split(" [CTRL] ")
for code in ctrl_code_list:
input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
else:
inputs = last_turn
outputs = ctrl_sent
input_ids = tokenizer.tokenize(last_turn)
input_ids = tokenizer.tokenize(inputs)
# output_ids
outputs = ctrl_sent
output_ids = tokenizer.tokenize(outputs)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
else:
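The control branch follows the same pattern: each control code is prefixed with the single ctrl_id instead of letting BPE split the literal "[CTRL]" marker. A compact sketch of the resulting id layout (all ids below are illustrative):

last_turn_ids = [11, 12, 13]        # tokenizer.tokenize(last_turn)
code_ids = [[21, 22], [31]]         # one list per entry in ctrl_code_list
ctrl_id = 50259                     # stand-in for tokenizer.ctrl_id

input_ids = list(last_turn_ids)
for ids in code_ids:
    input_ids.extend([ctrl_id] + ids)
# -> [11, 12, 13, 50259, 21, 22, 50259, 31]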
@@ -68,7 +81,7 @@ class ControlDialogDataset(torch.utils.data.Dataset):
def __init__(self, data, max_seq_len, pad_id, eod_id):
# need to deal with padding, label masking
self.data = data
self.max_seq_len
self.max_seq_len = max_seq_len
self.pad_id = pad_id
self.eod_id = eod_id
@@ -79,7 +92,7 @@ class ControlDialogDataset(torch.utils.data.Dataset):
data_dict = self.data[idx]
input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
assert len(input_ids) < self.max_seq_len, "Set a larger max_seq_len!"
assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
# length_of_loss_mask == length_of_text - 1
text = input_ids + [self.pad_id] + output_ids + [self.eod_id]
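A minimal sketch of the sample layout built here; the loss-mask placement below (loss only on the response and the trailing EOD) is an assumption, since that part is outside this hunk:

input_ids  = [11, 12, 13]           # tokenized context
output_ids = [21, 22]               # tokenized response
pad_id, eod_id = 50257, 50256       # stand-in ids

text = input_ids + [pad_id] + output_ids + [eod_id]
# With teacher forcing (tokens = text[:-1], labels = text[1:]) the loss mask
# needs len(text) - 1 entries, matching the comment above.
loss_mask = [0] * len(input_ids) + [1] * (len(output_ids) + 1)
assert len(loss_mask) == len(text) - 1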
@@ -118,4 +131,4 @@ def build_train_valid_test_datasets(data_folder, dataset_name, train_module, max
valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
test_dataset = ControlDialogDataset(test_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
return (train_dataset, valid_dataset, test_dataset)
return train_dataset, valid_dataset, test_dataset
@@ -16,20 +16,20 @@ def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
# reset attention mask and position ids
# Loop through the batches:
for b in range(micro_batch_size):
# Find indices where the EOD token is.
eod_index = position_ids[b, data[b] == eod_token_id]
eod_index = eod_index.clone()
# Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# Reset positions.
position_ids[b, (i + 1):] -= (i + 1 - prev_index)
prev_index = i + 1
# for b in range(micro_batch_size):
# # Find indices where the EOD token is.
# eod_index = position_ids[b, data[b] == eod_token_id]
# eod_index = eod_index.clone()
# # Loop through EOD indices:
# prev_index = 0
# for j in range(eod_index.size()[0]):
# i = eod_index[j]
# # Mask attention loss.
# attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# # Reset positions.
# position_ids[b, (i + 1):] -= (i + 1 - prev_index)
# prev_index = i + 1
# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
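With the per-EOD reset loop commented out, the helper reduces to a plain causal mask plus monotonic position ids. A sketch of that reduced behaviour, assuming the usual Megatron-style tril construction for the part not shown in this hunk:

import torch

def causal_masks_and_position_ids(data):
    # data: [micro_batch_size, seq_length] token ids
    micro_batch_size, seq_length = data.size()
    attention_mask = torch.tril(torch.ones(
        (micro_batch_size, seq_length, seq_length), device=data.device)
    ).view(micro_batch_size, 1, seq_length, seq_length)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # Convert to binary: True marks positions that must not be attended to.
    attention_mask = (attention_mask < 0.5)
    return attention_mask, position_ids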
@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser)
parser = _add_logging_args(parser)
parser = _add_dialog_ctrl_args(parser)
# Custom arguments.
if extra_args_provider is not None:
@@ -757,6 +758,8 @@ def _add_vit_args(parser):
def _add_dialog_ctrl_args(parser):
group = parser.add_argument_group(title="dialog control")
group.add_argument('--run-dialog', action='store_true',
help='run dialog modeling')
group.add_argument('--train-module', type=str, default="",
help='either control module or dialogue model (control or dialog)')
group.add_argument('--data-folder', type=str, default="",
@@ -765,7 +768,7 @@ def _add_dialog_ctrl_args(parser):
help='dataset name (e.g., wizard_of_wikipedia)')
group.add_argument('--max-seq-len', type=int, default=1024,
help='maximum sequence length')
group.add_argument('--spec_toks', type=str, default="[SEP],[CTRL],[PAD]",
group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
help='additional special tokens')
return parser
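The hyphenated --spec-toks now matches the spelling of the other flags; argparse still exposes it as args.spec_toks. How the comma-separated value is split into the special-tokens list used by the tokenizer below is not part of this diff, so the following is only an assumed sketch:

spec_toks = "[SEP],[CTRL],[PAD]"                  # default of args.spec_toks
special_tokens = spec_toks.split(",") if spec_toks else []
# -> ['[SEP]', '[CTRL]', '[PAD]']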
@@ -272,13 +272,14 @@ class _GPT2BPETokenizer(AbstractTokenizer):
self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
special_tokens=special_tokens, max_len=None)
self.eod_id = self.tokenizer.encoder['<|endoftext|>']
if len(special_tokens) > 0:
if "[PAD]" in special_tokens:
self.pad_id = self.tokenizer.encoder['[PAD]']
if "[SEP]" in special_tokens:
self.sep_id = self.tokenizer.encoder['[SEP]']
self.sep_id = self.tokenizer.special_tokens['[SEP]']
if "[CTRL]" in special_tokens:
self.ctrl_id = self.tokenizer.encoder['[CTRL]']
self.ctrl_id = self.tokenizer.special_tokens['[CTRL]']
if "[PAD]" in special_tokens:
self.pad_id = self.tokenizer.special_tokens['[PAD]']
@property
def vocab_size(self):
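Reading the ids from tokenizer.special_tokens reflects the fact that tokens added on top of the GPT-2 vocabulary live in that separate mapping, while '<|endoftext|>' stays in the base encoder. A small lookup sketch under that assumption (the helper name is hypothetical):

def special_token_id(gpt2_tokenizer, token):
    # Added control tokens sit in .special_tokens; the base vocab in .encoder.
    if token in gpt2_tokenizer.special_tokens:
        return gpt2_tokenizer.special_tokens[token]
    return gpt2_tokenizer.encoder[token]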
@@ -53,7 +53,6 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory
def print_datetime(string):
"""Note that this call will sync across all ranks."""
torch.distributed.barrier()
@@ -325,6 +324,8 @@ def setup_model_and_optimizer(model_provider_func):
torch.distributed.barrier()
timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
# need to set train_samples to None
args.train_samples = None
torch.distributed.barrier()
timers('load-checkpoint').stop()
timers.log(['load-checkpoint'])
@@ -792,28 +793,50 @@ def build_train_valid_test_data_iterators(
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
args.eval_iters * args.global_batch_size
if args.run_dialog:
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
args.iteration = 0
# Data loader only on rank 0 of each model parallel group.
if mpu.get_tensor_model_parallel_rank() == 0:
if args.run_dialog:
# Build the datasets.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()
print_rank_0(' > datasets target sizes:')
train_size = len(train_ds)
valid_size = len(valid_ds)
test_size = len(test_ds)
print_rank_0(' train: {}'.format(train_size))
print_rank_0(' validation: {}'.format(valid_size))
print_rank_0(' test: {}'.format(test_size))
args.train_iters = train_size // args.global_batch_size
args.eval_iters = valid_size // args.global_batch_size
args.test_iters = test_size // args.global_batch_size
# Number of train/valid/test samples.
if args.train_samples:
train_samples = args.train_samples
else:
train_samples = args.train_iters * args.global_batch_size
eval_iters = (args.train_iters // args.eval_interval + 1) * \
args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [train_samples,
eval_iters * args.global_batch_size,
test_iters * args.global_batch_size]
print_rank_0(' > datasets target sizes (minimum size):')
print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
# Build the datasets.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
train_val_test_num_samples)
# Number of train/valid/test samples.
if args.train_samples:
train_samples = args.train_samples
else:
train_samples = args.train_iters * args.global_batch_size
eval_iters = (args.train_iters // args.eval_interval + 1) * \
args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [train_samples,
eval_iters * args.global_batch_size,
test_iters * args.global_batch_size]
print_rank_0(' > datasets target sizes (minimum size):')
print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
# Build the datasets.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
train_val_test_num_samples)
# Build dataloders.
train_dataloader = build_pretraining_data_loader(
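In the --run-dialog branch the iteration counts are derived from the finite dataset sizes instead of a target sample count. The arithmetic, with illustrative numbers:

global_batch_size = 16
train_size, valid_size, test_size = 18_000, 2_000, 2_000   # len(train_ds), ...
train_iters = train_size // global_batch_size   # 1125
eval_iters  = valid_size // global_batch_size   # 125
test_iters  = test_size  // global_batch_size   # 125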
@@ -9,11 +9,11 @@ from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
# from megatron.data.gpt_dataset import build_train_valid_test_datasets
from dialog_ctrl.dialog_dataset import build_train_valid_test_datasets
from dialogctrl.dialog_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel
from megatron.training import pretrain
# from megatron.utils import get_ltor_masks_and_position_ids
from dialog_ctrl.utils import get_ltor_attention_masks_and_position_ids
from dialogctrl.utils import get_ltor_attention_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
@@ -52,7 +52,7 @@ def get_batch(data_iterator):
loss_mask = data_b['loss_mask'].float()
# Get the attention_mask and position ids.
attention_masks, position_ids = get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)
attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)
return tokens, labels, loss_mask, attention_mask, position_ids
@@ -86,7 +86,7 @@ def forward_step(data_iterator, model):
def train_valid_test_datasets_provider():
"""Build train, valid, and test datasets for control module"""
"""Build train, valid, and test datasets for dialog/control module"""
args = get_args()
print_rank_0('> building train, validation, and test datasets for %s module ...' % args.train_module)
@@ -99,6 +99,8 @@ def train_valid_test_datasets_provider():
seed=args.seed)
print_rank_0("> finished creating datasets for %s module ..." % args.train_module)
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
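For reference, a sketch of the call the provider above reduces to, assuming the dialogctrl.dialog_dataset signature shown earlier in this diff (any keyword not visible there is an assumption):

train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_folder=args.data_folder,
    dataset_name=args.dataset_name,
    train_module=args.train_module,   # "dialog" or "control"
    max_seq_len=args.max_seq_len,
    seed=args.seed)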