"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "3250d3df168c956389bd16956aa458ce111570d0"
Commit d7509658 authored by zihanl

change folder name and add dialog training

parent 6f72a285
@@ -6,6 +6,7 @@ import torch
 import numpy as np
 from megatron import get_tokenizer
+from megatron import print_rank_0

 def read_data(tokenizer, data_path, train_module):
     """read and tokenize dialog data"""
@@ -24,10 +25,17 @@ def read_data(tokenizer, data_path, train_module):
             # only take the last three turns in the dialog context
             turns = dialog_context.split(" [SEP] ")
             turns = turns[-3:]
-            context = " [SEP] ".join(turns)
-            input_ids = tokenizer.tokenize(context)
+            # input_ids
+            for idx, turn in enumerate(turns):
+                if idx == 0:
+                    input_ids = tokenizer.tokenize(turn)
+                else:
+                    input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
+            # output_ids
             output_ids = tokenizer.tokenize(response)
             data_list.append({"input_ids": input_ids, "output_ids": output_ids})
         elif train_module == "control":
@@ -40,14 +48,19 @@ def read_data(tokenizer, data_path, train_module):
             turns = dialog_context.split(" [SEP] ")
             last_turn = turns[-1]
+            # input_ids
             if ctrl_code:
-                inputs = last_turn + " [CTRL] " + ctrl_code
+                input_ids = tokenizer.tokenize(last_turn)
+                ctrl_code_list = ctrl_code.split(" [CTRL] ")
+                for code in ctrl_code_list:
+                    input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
             else:
-                inputs = last_turn
-            outputs = ctrl_sent
-            input_ids = tokenizer.tokenize(inputs)
+                input_ids = tokenizer.tokenize(last_turn)
+            # output_ids
+            outputs = ctrl_sent
             output_ids = tokenizer.tokenize(outputs)
             data_list.append({"input_ids": input_ids, "output_ids": output_ids})
         else:
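Note: the loop added in the hunks above replaces string-level joining with token-level splicing. Below is a minimal standalone sketch of the same pattern using a toy stand-in tokenizer (hypothetical; the real code uses Megatron's GPT-2 tokenizer, whose sep_id and ctrl_id come from the added special tokens).

# Toy stand-in tokenizer (hypothetical, for illustration only).
class ToyTokenizer:
    sep_id = 0
    def tokenize(self, text):
        # stand-in ids: one fake id per whitespace-separated word
        return [len(word) for word in text.split()]

tokenizer = ToyTokenizer()
turns = "hi there [SEP] how are you [SEP] fine thanks".split(" [SEP] ")[-3:]

# New behavior: tokenize each turn separately and splice the [SEP] token id
# between turns, instead of re-joining the strings and tokenizing once.
input_ids = []
for idx, turn in enumerate(turns):
    if idx == 0:
        input_ids = tokenizer.tokenize(turn)
    else:
        input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
print(input_ids)  # [2, 5, 0, 3, 3, 3, 0, 4, 6]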
@@ -68,7 +81,7 @@ class ControlDialogDataset(torch.utils.data.Dataset):
     def __init__(self, data, max_seq_len, pad_id, eod_id):
         # need to deal with padding, label masking
         self.data = data
-        self.max_seq_len
+        self.max_seq_len = max_seq_len
         self.pad_id = pad_id
         self.eod_id = eod_id
@@ -79,7 +92,7 @@ class ControlDialogDataset(torch.utils.data.Dataset):
         data_dict = self.data[idx]
         input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
-        assert len(input_ids) < self.max_seq_len, "Set a larger max_seq_len!"
+        assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
         # length_of_loss_mask == length_of_text - 1
         text = input_ids + [self.pad_id] + output_ids + [self.eod_id]
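Note: a hedged sketch of the sequence layout built on the line above, using made-up token ids. The loss-mask split shown is an assumption for illustration only; this hunk only shows the text layout and the length relation.

pad_id, eod_id = 0, 1          # made-up ids
input_ids = [11, 12, 13]       # tokenized dialog context
output_ids = [21, 22]          # tokenized response
text = input_ids + [pad_id] + output_ids + [eod_id]

# For next-token prediction, tokens and labels are shifted by one, which is why
# the comment in the diff says length_of_loss_mask == length_of_text - 1.
tokens, labels = text[:-1], text[1:]
# One plausible masking scheme (an assumption, not shown in this hunk): only the
# response side of each shifted pair contributes to the loss.
loss_mask = [0] * len(input_ids) + [1] * (len(labels) - len(input_ids))
print(tokens)     # [11, 12, 13, 0, 21, 22]
print(labels)     # [12, 13, 0, 21, 22, 1]
print(loss_mask)  # [0, 0, 0, 1, 1, 1]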
@@ -118,4 +131,4 @@ def build_train_valid_test_datasets(data_folder, dataset_name, train_module, max
     valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
     test_dataset = ControlDialogDataset(test_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
-    return (train_dataset, valid_dataset, test_dataset)
+    return train_dataset, valid_dataset, test_dataset
@@ -16,20 +16,20 @@ def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
     # reset attentino mask and position ids
     # Loop through the batches:
-    for b in range(micro_batch_size):
-        # Find indecies where EOD token is.
-        eod_index = position_ids[b, data[b] == eod_token_id]
-        eod_index = eod_index.clone()
-        # Loop through EOD indecies:
-        prev_index = 0
-        for j in range(eod_index.size()[0]):
-            i = eod_index[j]
-            # Mask attention loss.
-            attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
-            # Reset positions.
-            position_ids[b, (i + 1):] -= (i + 1 - prev_index)
-            prev_index = i + 1
+    # for b in range(micro_batch_size):
+    #     # Find indecies where EOD token is.
+    #     eod_index = position_ids[b, data[b] == eod_token_id]
+    #     eod_index = eod_index.clone()
+    #     # Loop through EOD indecies:
+    #     prev_index = 0
+    #     for j in range(eod_index.size()[0]):
+    #         i = eod_index[j]
+    #         # Mask attention loss.
+    #         attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
+    #         # Reset positions.
+    #         position_ids[b, (i + 1):] -= (i + 1 - prev_index)
+    #         prev_index = i + 1

     # Convert attention mask to binary:
     attention_mask = (attention_mask < 0.5)
...
@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
+    parser = _add_dialog_ctrl_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:
@@ -757,6 +758,8 @@ def _add_vit_args(parser):

 def _add_dialog_ctrl_args(parser):
     group = parser.add_argument_group(title="dialog control")
+    group.add_argument('--run-dialog', action='store_true',
+                       help='run dialog modeling')
     group.add_argument('--train-module', type=str, default="",
                        help='either control module or dialogue model (control or dialog)')
     group.add_argument('--data-folder', type=str, default="",

@@ -765,7 +768,7 @@ def _add_dialog_ctrl_args(parser):
                        help='dataset name (e.g., wizard_of_wikipedia)')
     group.add_argument('--max-seq-len', type=int, default=1024,
                        help='maximum sequence length')
-    group.add_argument('--spec_toks', type=str, default="[SEP],[CTRL],[PAD]",
+    group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
                        help='additional special tokens')
     return parser
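Note: a standalone argparse sketch (not Megatron's parser wiring) of the two flag changes above. It illustrates why renaming --spec_toks to --spec-toks is safe for downstream code: argparse converts dashes in option names to underscores for attribute access, so the value is still read as args.spec_toks. The parse_args argument list is illustrative.

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="dialog control")
group.add_argument('--run-dialog', action='store_true', help='run dialog modeling')
group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
                   help='additional special tokens')

# Dashes become underscores in the parsed namespace.
args = parser.parse_args(['--run-dialog'])
print(args.run_dialog)            # True
print(args.spec_toks.split(','))  # ['[SEP]', '[CTRL]', '[PAD]']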
@@ -272,13 +272,14 @@ class _GPT2BPETokenizer(AbstractTokenizer):
         self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
                                        special_tokens=special_tokens, max_len=None)
         self.eod_id = self.tokenizer.encoder['<|endoftext|>']
         if len(special_tokens) > 0:
-            if "[PAD]" in special_tokens:
-                self.pad_id = self.tokenizer.encoder['[PAD]']
             if "[SEP]" in special_tokens:
-                self.sep_id = self.tokenizer.encoder['[SEP]']
+                self.sep_id = self.tokenizer.special_tokens['[SEP]']
             if "[CTRL]" in special_tokens:
-                self.ctrl_id = self.tokenizer.encoder['[CTRL]']
+                self.ctrl_id = self.tokenizer.special_tokens['[CTRL]']
+            if "[PAD]" in special_tokens:
+                self.pad_id = self.tokenizer.special_tokens['[PAD]']

     @property
     def vocab_size(self):
...
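Note: a toy illustration of the lookup change above. In tokenizers of this style, special tokens added on top of the pretrained BPE vocabulary typically live in a separate special_tokens mapping rather than in the base encoder, so their ids have to be read from that mapping. The dictionaries and id values below are made up.

# Made-up vocab and ids, for illustration only.
base_encoder = {'<|endoftext|>': 50256}                              # pretrained BPE vocab
special_tokens = {'[SEP]': 50257, '[CTRL]': 50258, '[PAD]': 50259}   # tokens added afterwards

eod_id = base_encoder['<|endoftext|>']   # still read from the base vocab
sep_id = special_tokens['[SEP]']         # new lookup path in the diff
# base_encoder['[SEP]'] would raise KeyError, which is what the old lookup risked.
print(eod_id, sep_id)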
@@ -53,7 +53,6 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory

 def print_datetime(string):
     """Note that this call will sync across all ranks."""
     torch.distributed.barrier()
@@ -325,6 +324,8 @@ def setup_model_and_optimizer(model_provider_func):
         torch.distributed.barrier()
         timers('load-checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
+        # need to set train_samples to None
+        args.train_samples = None
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])
@@ -792,28 +793,50 @@ def build_train_valid_test_data_iterators(
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                                       args.eval_iters * args.global_batch_size

+    if args.run_dialog:
+        args.consumed_train_samples = 0
+        args.consumed_valid_samples = 0
+        args.iteration = 0
+
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_tensor_model_parallel_rank() == 0:

-        # Number of train/valid/test samples.
-        if args.train_samples:
-            train_samples = args.train_samples
-        else:
-            train_samples = args.train_iters * args.global_batch_size
-        eval_iters = (args.train_iters // args.eval_interval + 1) * \
-                     args.eval_iters
-        test_iters = args.eval_iters
-        train_val_test_num_samples = [train_samples,
-                                      eval_iters * args.global_batch_size,
-                                      test_iters * args.global_batch_size]
-        print_rank_0(' > datasets target sizes (minimum size):')
-        print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
-        print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
-        print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
-
-        # Build the datasets.
-        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
-            train_val_test_num_samples)
+        if args.run_dialog:
+            # Build the datasets.
+            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()
+            print_rank_0(' > datasets target sizes:')
+            train_size = len(train_ds)
+            valid_size = len(valid_ds)
+            test_size = len(test_ds)
+            print_rank_0(' train: {}'.format(train_size))
+            print_rank_0(' validation: {}'.format(valid_size))
+            print_rank_0(' test: {}'.format(test_size))
+            args.train_iters = train_size // args.global_batch_size
+            args.eval_iters = valid_size // args.global_batch_size
+            args.test_iters = test_size // args.global_batch_size
+        else:
+            # Number of train/valid/test samples.
+            if args.train_samples:
+                train_samples = args.train_samples
+            else:
+                train_samples = args.train_iters * args.global_batch_size
+            eval_iters = (args.train_iters // args.eval_interval + 1) * \
+                         args.eval_iters
+            test_iters = args.eval_iters
+            train_val_test_num_samples = [train_samples,
+                                          eval_iters * args.global_batch_size,
+                                          test_iters * args.global_batch_size]
+            print_rank_0(' > datasets target sizes (minimum size):')
+            print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
+            print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
+            print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
+
+            # Build the datasets.
+            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
+                train_val_test_num_samples)

         # Build dataloders.
         train_dataloader = build_pretraining_data_loader(
...
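Note: a hedged arithmetic check of the new --run-dialog sizing path above, with made-up numbers: iteration counts are derived from the dataset lengths and the global batch size instead of from --train-samples.

# Made-up dataset sizes and batch size, for illustration only.
global_batch_size = 32
train_size, valid_size, test_size = 20000, 2000, 2000

train_iters = train_size // global_batch_size   # 625
eval_iters = valid_size // global_batch_size    # 62
test_iters = test_size // global_batch_size     # 62
print(train_iters, eval_iters, test_iters)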
@@ -9,11 +9,11 @@ from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
 # from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from dialog_ctrl.dialog_dataset import build_train_valid_test_datasets
+from dialogctrl.dialog_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel
 from megatron.training import pretrain
 # from megatron.utils import get_ltor_masks_and_position_ids
-from dialog_ctrl.utils import get_ltor_attention_masks_and_position_ids
+from dialogctrl.utils import get_ltor_attention_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group

 def model_provider(pre_process=True, post_process=True):
@@ -52,7 +52,7 @@ def get_batch(data_iterator):
     loss_mask = data_b['loss_mask'].float()

     # Get the attention_mask and postition ids.
-    attention_masks, position_ids = get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)
+    attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)

     return tokens, labels, loss_mask, attention_mask, position_ids
@@ -86,7 +86,7 @@ def forward_step(data_iterator, model):

 def train_valid_test_datasets_provider():
-    """Build train, valid, and test datasets for control module"""
+    """Build train, valid, and test datasets for dialog/control module"""
     args = get_args()

     print_rank_0('> building train, validation, and test datasets for %s module ...' % args.train_module)
@@ -99,6 +99,8 @@ def train_valid_test_datasets_provider():
         seed=args.seed)
     print_rank_0("> finished creating datasets for %s module ..." % args.train_module)
+
+    return train_ds, valid_ds, test_ds

 if __name__ == "__main__":
...