Commit 6f72a285 authored by zihanl

add dialog dataset and special tokens in tokenizer

parent 1f7cfdf6
"""Build Dataset for Controllable Coversational Model"""
import os
import torch
import numpy as np
from megatron import get_tokenizer
def read_data(tokenizer, data_path, train_module):
"""read and tokenize dialog data"""
data_list = []
with open(data_path, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
length_split = len(splits)
assert length_split == 2 or length_split == 3 or length_split == 4
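            # Assumed line layouts, inferred from how the fields are used below
            # (the commit itself does not document the file format):
            #   2 columns: dialog_context \t response
            #   3 columns: dialog_context \t control_sentence \t response
            #   4 columns: dialog_context \t control_code \t control_sentence \t response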
if train_module == "dialog":
dialog_context = splits[0]
response = splits[-1]
# only take the last three turns in the dialog context
turns = dialog_context.split(" [SEP] ")
turns = turns[-3:]
context = " [SEP] ".join(turns)
input_ids = tokenizer.tokenize(context)
output_ids = tokenizer.tokenize(response)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
elif train_module == "control":
if length_split == 2:
continue
dialog_context = splits[0]
ctrl_sent = splits[-2]
ctrl_code = splits[1] if length_split == 4 else None
turns = dialog_context.split(" [SEP] ")
last_turn = turns[-1]
if ctrl_code:
inputs = last_turn + " [CTRL] " + ctrl_code
else:
inputs = last_turn
outputs = ctrl_sent
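                # Illustrative values (made up, not from the dataset) for the strings built above:
                #   last_turn = "do you like jazz?", ctrl_code = "jazz"
                #   inputs    = "do you like jazz? [CTRL] jazz"
                #   outputs   = the control sentence taken from this line of the file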
input_ids = tokenizer.tokenize(inputs)
output_ids = tokenizer.tokenize(outputs)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
else:
raise ValueError("Please input a correct train-module name! (either dialog or cnotrol))")
return data_list
def data_shuffle(data, seed):
# set random seed to make the shuffling reproducible
np.random.seed(seed)
np.random.shuffle(data)
return data
class ControlDialogDataset(torch.utils.data.Dataset):
def __init__(self, data, max_seq_len, pad_id, eod_id):
# need to deal with padding, label masking
self.data = data
        self.max_seq_len = max_seq_len
self.pad_id = pad_id
self.eod_id = eod_id
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data_dict = self.data[idx]
input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
assert len(input_ids) < self.max_seq_len, "Set a larger max_seq_len!"
# length_of_loss_mask == length_of_text - 1
text = input_ids + [self.pad_id] + output_ids + [self.eod_id]
loss_mask = [0]*len(input_ids) + [1]*(len(output_ids)+1)
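        # Worked example (illustrative, not from the commit): with input_ids = [a, b]
        # and output_ids = [x, y], the packed sample is
        #   text      = [a, b, PAD, x, y, EOD]
        #   loss_mask = [0, 0, 1,   1, 1]
        # so after get_batch shifts tokens/labels by one position, the loss only covers
        # the predictions of x, y, and EOD.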
text_len = len(text)
if text_len > self.max_seq_len:
text = text[:self.max_seq_len]
loss_mask = loss_mask[:self.max_seq_len-1]
else:
text += [self.pad_id] * (self.max_seq_len - text_len)
loss_mask += [0] * (self.max_seq_len - text_len)
return {"text": np.array(text, dtype=np.int64), "loss_mask": np.array(loss_mask, dtype=np.int64)}
def build_train_valid_test_datasets(data_folder, dataset_name, train_module, max_seq_len, seed):
"""Build train, valid, and test datasets."""
dataname_dict = {"wizard_of_wikipedia": {"train": "train_entity_based_control.txt", "valid": "valid_random_split_entity_based_control.txt", "test": "test_random_split_entity_based_control.txt"}}
train_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["train"])
valid_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["valid"])
test_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["test"])
tokenizer = get_tokenizer()
train_data_list = read_data(tokenizer, train_data_path, train_module)
valid_data_list = read_data(tokenizer, valid_data_path, train_module)
test_data_list = read_data(tokenizer, test_data_path, train_module)
# shuffle the training data
train_data_list = data_shuffle(train_data_list, seed)
# build train, valid, and test datasets
train_dataset = ControlDialogDataset(train_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
test_dataset = ControlDialogDataset(test_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
return (train_dataset, valid_dataset, test_dataset)
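# Illustrative usage of this module (assumes Megatron's global args/tokenizer have been
# initialized so that get_tokenizer() works; paths below are placeholders):
#   train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
#       data_folder="/path/to/data", dataset_name="wizard_of_wikipedia",
#       train_module="dialog", max_seq_len=1024, seed=1234)
#   sample = train_ds[0]   # {"text": int64 array, "loss_mask": int64 array}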
import torch
def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
"""Build attention masks and position id for left to right model."""
micro_batch_size, seq_length = data.size()
# Attention mask
    attention_mask = torch.tril(torch.ones(
        (micro_batch_size, seq_length, seq_length), device=data.device)).view(
            micro_batch_size, 1, seq_length, seq_length)
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
    # Reset attention mask and position ids at EOD boundaries.
# Loop through the batches:
for b in range(micro_batch_size):
        # Find indices where EOD token is.
eod_index = position_ids[b, data[b] == eod_token_id]
eod_index = eod_index.clone()
        # Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# Reset positions.
position_ids[b, (i + 1):] -= (i + 1 - prev_index)
prev_index = i + 1
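        # Net effect (illustrative): for data[b] = [t0, t1, EOD, t2, t3] the position ids
        # become [0, 1, 2, 0, 1] and tokens after the EOD can no longer attend to tokens
        # at or before it, so each packed sample is modeled independently.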
# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
return attention_mask, position_ids
\ No newline at end of file
...@@ -752,3 +752,20 @@ def _add_vit_args(parser):
                       help='patch dimension used in vit')
    return parser
def _add_dialog_ctrl_args(parser):
group = parser.add_argument_group(title="dialog control")
group.add_argument('--train-module', type=str, default="",
help='either control module or dialogue model (control or dialog)')
group.add_argument('--data-folder', type=str, default="",
help='data folder (path of the data folder)')
group.add_argument('--dataset-name', type=str, default="",
help='dataset name (e.g., wizard_of_wikipedia)')
group.add_argument('--max-seq-len', type=int, default=1024,
help='maximum sequence length')
group.add_argument('--spec_toks', type=str, default="[SEP],[CTRL],[PAD]",
help='additional special tokens')
return parser
...@@ -40,7 +40,7 @@ def build_tokenizer(args):
                                            vocab_extra_ids=args.vocab_extra_ids)
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file, special_tokens=args.spec_toks)
    else:
        raise NotImplementedError('{} tokenizer is not '
                                  'implemented.'.format(args.tokenizer_type))
...@@ -260,13 +260,25 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
class _GPT2BPETokenizer(AbstractTokenizer):
    """Original GPT2 BPE tokenizer."""

    def __init__(self, vocab_file, merge_file, special_tokens=None):
        name = 'GPT2 BPE'
        super().__init__(name)

        if special_tokens is not None:
            # for controllable dialog, special_tokens: "[SEP],[CTRL],[PAD]"
            special_tokens = special_tokens.split(",")
        else:
            special_tokens = []
        self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
                                       special_tokens=special_tokens, max_len=None)
        self.eod_id = self.tokenizer.encoder['<|endoftext|>']
        if len(special_tokens) > 0:
            if "[PAD]" in special_tokens:
                self.pad_id = self.tokenizer.encoder['[PAD]']
            if "[SEP]" in special_tokens:
                self.sep_id = self.tokenizer.encoder['[SEP]']
            if "[CTRL]" in special_tokens:
                self.ctrl_id = self.tokenizer.encoder['[CTRL]']

    @property
    def vocab_size(self):
    ...
"""Train dialogue model based on GPT"""
import torch
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
# from megatron.data.gpt_dataset import build_train_valid_test_datasets
from dialog_ctrl.dialog_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel
from megatron.training import pretrain
# from megatron.utils import get_ltor_masks_and_position_ids
from dialog_ctrl.utils import get_ltor_attention_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
return model
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text', 'loss_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
loss_mask = data_b['loss_mask'].float()
    # Get the attention mask and position ids.
    attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(
        tokens, tokenizer.eod_id)

    return tokens, labels, loss_mask, attention_mask, position_ids
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider():
"""Build train, valid, and test datasets for control module"""
args = get_args()
print_rank_0('> building train, validation, and test datasets for %s module ...' % args.train_module)
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_folder=args.data_folder,
dataset_name=args.dataset_name,
train_module=args.train_module,
max_seq_len=args.max_seq_len,
seed=args.seed)
print_rank_0("> finished creating datasets for %s module ..." % args.train_module)
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
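# Illustrative launch sketch (assumptions: the script name and all paths below are
# placeholders; the usual Megatron GPT model/optimizer/distributed flags are elided as "..."):
#   python train_gpt_dialog.py \
#       --tokenizer-type GPT2BPETokenizer \
#       --vocab-file gpt2-vocab.json --merge-file gpt2-merges.txt \
#       --train-module control --data-folder /path/to/data \
#       --dataset-name wizard_of_wikipedia --max-seq-len 1024 \
#       --spec_toks "[SEP],[CTRL],[PAD]" \
#       ...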