Commit 5aa3d5e3 authored by zihanl

add dialctrl into tasks

parent 3c363d57
@@ -29,18 +29,22 @@ def read_data(tokenizer, data_path, train_module):
                 response = splits[-1]

                 # only take the last three turns in the dialog context
                 turns = dialog_context.split(" [SEP] ")
-                # turns = turns[-3:]
+                turns = turns[-3:]

                 # input_ids
                 for idx, turn in enumerate(turns):
+                    if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
+                        turn = turn + " ."
                     if idx == 0:
                         input_ids = tokenizer.tokenize(turn)
                     else:
-                        input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
+                        # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
+                        input_ids.extend(tokenizer.tokenize(turn))
                 if length_split > 2:
                     # when there is control sentence, add it into the input_ids
-                    input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
+                    input_ids.extend(tokenizer.tokenize("( " + ctrl_sent + " ) ."))

                 # output_ids
                 output_ids = tokenizer.tokenize(response)
@@ -55,16 +59,35 @@ def read_data(tokenizer, data_path, train_module):
                 ctrl_code = splits[1] if length_split == 4 else None

                 turns = dialog_context.split(" [SEP] ")
-                last_turn = turns[-1]
-                # input_ids
+                # last_turn = turns[-1]
+                # turns = turns[-3:]
+                # for idx, turn in enumerate(turns):
+                #     if idx == 0:
+                #         input_ids = tokenizer.tokenize(turn)
+                #     else:
+                #         # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
+                #         input_ids.extend(tokenizer.tokenize(turn))
+
+                # # input_ids
+                # if ctrl_code:
+                #     ctrl_code_list = ctrl_code.split(" [CTRL] ")
+                #     for code in ctrl_code_list:
+                #         # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
+                #         input_ids.extend(tokenizer.tokenize(code + " ."))
+
+                # put control code at the beginning
+                input_ids = []
                 if ctrl_code:
-                    input_ids = tokenizer.tokenize(last_turn)
                     ctrl_code_list = ctrl_code.split(" [CTRL] ")
                     for code in ctrl_code_list:
-                        input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
-                else:
-                    input_ids = tokenizer.tokenize(last_turn)
+                        input_ids.extend(tokenizer.tokenize("( " + code + " )"))
+
+                turns = turns[-3:]
+                for turn in turns:
+                    if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
+                        turn = turn + " ."
+                    input_ids.extend(tokenizer.tokenize(turn))

                 # output_ids
                 outputs = ctrl_sent
@@ -105,8 +128,9 @@ class ControlDialogDataset(torch.utils.data.Dataset):
         assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"

         # length_of_loss_mask == length_of_text - 1
-        text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
-        loss_mask = [0]*len(input_ids) + [1]*(len(output_ids)+1)
+        # text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
+        text = input_ids + output_ids + [self.eod_id]
+        loss_mask = [0]*(len(input_ids)-1) + [1]*(len(output_ids)+1)

         text_len = len(text)
         if text_len > self.max_seq_len+1:
...
@@ -41,7 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
-    parser = _add_dialog_ctrl_args(parser)
+    # parser = _add_dialog_ctrl_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:
@@ -755,22 +755,22 @@ def _add_vit_args(parser):
     return parser

-def _add_dialog_ctrl_args(parser):
-    group = parser.add_argument_group(title="dialog control")
-    group.add_argument('--run-dialog', action='store_true',
-                       help='run dialog modeling')
-    group.add_argument('--num-epoch', type=int, default=30,
-                       help='number of epoches to train the model')
-    group.add_argument('--train-module', type=str, default="",
-                       help='either control module or dialogue model (control or dialog)')
-    group.add_argument('--data-folder', type=str, default="",
-                       help='data folder (path of the data folder)')
-    group.add_argument('--dataset-name', type=str, default="",
-                       help='dataset name (e.g., wizard_of_wikipedia)')
-    group.add_argument('--max-seq-len', type=int, default=1024,
-                       help='maximum sequence length')
-    group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
-                       help='additional special tokens')
-    return parser
+# def _add_dialog_ctrl_args(parser):
+#     group = parser.add_argument_group(title="dialog control")
+#     group.add_argument('--run-dialog', action='store_true',
+#                        help='run dialog modeling')
+#     group.add_argument('--num-epoch', type=int, default=30,
+#                        help='number of epoches to train the model')
+#     group.add_argument('--train-module', type=str, default="",
+#                        help='either control module or dialogue model (control or dialog)')
+#     group.add_argument('--data-folder', type=str, default="",
+#                        help='data folder (path of the data folder)')
+#     group.add_argument('--dataset-name', type=str, default="",
+#                        help='dataset name (e.g., wizard_of_wikipedia)')
+#     group.add_argument('--max-seq-len', type=int, default=1024,
+#                        help='maximum sequence length')
+#     group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
+#                        help='additional special tokens')
+#     return parser
@@ -344,21 +344,21 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     print_rank_0(f' checkpoint version {checkpoint_version}')
     fix_query_key_value_ordering(model, checkpoint_version)

-    if not args.run_dialog:
-        # Original pre-train GPT setting
-        # Optimizer.
-        if not release and not args.finetune and not args.no_load_optim:
-            try:
-                if optimizer is not None:
-                    optimizer.load_state_dict(state_dict['optimizer'])
-                if lr_scheduler is not None:
-                    lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
-            except KeyError:
-                print_rank_0('Unable to load optimizer from checkpoint {}. '
-                             'Specify --no-load-optim or --finetune to prevent '
-                             'attempting to load the optimizer state, '
-                             'exiting ...'.format(checkpoint_name))
-                sys.exit()
+    # if not args.run_dialog:
+    # Original pre-train GPT setting
+    # Optimizer.
+    if not release and not args.finetune and not args.no_load_optim:
+        try:
+            if optimizer is not None:
+                optimizer.load_state_dict(state_dict['optimizer'])
+            if lr_scheduler is not None:
+                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+        except KeyError:
+            print_rank_0('Unable to load optimizer from checkpoint {}. '
+                         'Specify --no-load-optim or --finetune to prevent '
+                         'attempting to load the optimizer state, '
+                         'exiting ...'.format(checkpoint_name))
+            sys.exit()

     # rng states.
     if not release and not args.finetune and not args.no_load_rng:
...
@@ -247,6 +247,7 @@ def generate_samples_interactive(model, print_frequency=24):
                 terminate_runs = 1
             else:
                 context_tokens = tokenizer.tokenize(raw_text)
+                # context_tokens = context_tokens + [tokenizer.sep_id]
                 context_length = len(context_tokens)

                 if context_length >= (args.seq_length // 2):
@@ -299,9 +300,14 @@ def generate_samples_interactive(model, print_frequency=24):
                 print("\nContext:", raw_text, flush=True)

                 decode_tokens, _ = decode_tokens
+                # print("tokenized inputs:", tokenizer.tokenize(raw_text))
+                # print("decode_tokens:", decode_tokens)
                 decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                 trim_decode_tokens = tokenizer.detokenize(
                     decode_tokens)[raw_text_len:]
+                # trim_decode_tokens = tokenizer.detokenize(
+                #     decode_tokens[context_length:])
                 print("\nMegatron-LM:", trim_decode_tokens, flush=True)

             if mpu.is_pipeline_first_stage() \
@@ -314,6 +320,9 @@ def generate_samples_interactive(model, print_frequency=24):
                 decode_tokens = decode_tokens[0].cpu().numpy().tolist()
                 trim_decode_tokens = tokenizer.detokenize(
                     decode_tokens)[raw_text_len:]
+                # print("decode_tokens:", decode_tokens)
+                # trim_decode_tokens = tokenizer.detokenize(
+                #     decode_tokens[context_length:])
                 print("\nMegatron-LM:", trim_decode_tokens, flush=True)

             input("\nPress Enter to continue >>>")
...
@@ -41,6 +41,7 @@ def build_tokenizer(args):
     elif args.tokenizer_type == 'GPT2BPETokenizer':
         assert args.merge_file is not None
         tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file, special_tokens=args.spec_toks)
+        # tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
     else:
         raise NotImplementedError('{} tokenizer is not '
                                   'implemented.'.format(args.tokenizer_type))
@@ -273,7 +274,7 @@ class _GPT2BPETokenizer(AbstractTokenizer):
                                        special_tokens=special_tokens, max_len=None)
         self.eod_id = self.tokenizer.encoder['<|endoftext|>']

-        if len(special_tokens) > 0:
+        if special_tokens is not None and len(special_tokens) > 0:
             if "[SEP]" in special_tokens:
                 self.sep_id = self.tokenizer.special_tokens['[SEP]']
             if "[CTRL]" in special_tokens:
...
@@ -180,8 +180,12 @@ def pretrain(train_valid_test_dataset_provider,
                                    valid_data_iterator, model,
                                    iteration, False)

-        if e >= 8 and e <= 13 and args.save and iteration != 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+        # if args.train_module == "dialog":
+        #     if (e+1) >= 6 and (e+1) <= 15 and args.save and iteration != 0:
+        #         save_checkpoint(iteration, model, optimizer, lr_scheduler)
+        if args.train_module == "control":
+            if (e+1) >= 5 and (e+1) <= 9 and args.save and iteration != 0:
+                save_checkpoint(iteration, model, optimizer, lr_scheduler)

     if args.do_test:
         # Run on test data.
@@ -845,7 +849,7 @@ def build_train_valid_test_data_iterators(
         print_rank_0(' validation: {}'.format(valid_size))
         print_rank_0(' test: {}'.format(test_size))

-        batch_size = args.micro_batch_size * args.data_parallel_size
+        batch_size = args.global_batch_size
         args.train_iters = train_size // batch_size + 1
         args.eval_iters = valid_size // batch_size + 1
         args.test_iters = test_size // batch_size + 1
...

"""Build Dataset for Controllable Conversational Model"""
import os
import torch
import numpy as np

from megatron import get_tokenizer
from megatron import print_rank_0


def read_data(tokenizer, data_path, train_module):
    """read and tokenize dialog data"""
    data_list = []
    with open(data_path, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            length_split = len(splits)
            assert length_split == 2 or length_split == 3 or length_split == 4

            if train_module == "dialog":
                # if length_split == 2:
                #     continue
                dialog_context = splits[0]
                if length_split > 2:
                    ctrl_sent = splits[-2]
                response = splits[-1]

                # only take the last three turns in the dialog context
                turns = dialog_context.split(" [SEP] ")
                turns = turns[-3:]

                # input_ids
                for idx, turn in enumerate(turns):
                    if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
                        turn = turn + " ."
                    if idx == 0:
                        input_ids = tokenizer.tokenize(turn)
                    else:
                        # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
                        input_ids.extend(tokenizer.tokenize(turn))
                if length_split > 2:
                    # when there is control sentence, add it into the input_ids
                    # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
                    input_ids.extend(tokenizer.tokenize("( " + ctrl_sent + " ) ."))

                # output_ids
                output_ids = tokenizer.tokenize(response)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            elif train_module == "control":
                if length_split == 2:
                    continue
                dialog_context = splits[0]
                ctrl_sent = splits[-2]
                ctrl_code = splits[1] if length_split == 4 else None

                turns = dialog_context.split(" [SEP] ")
                # last_turn = turns[-1]
                # turns = turns[-3:]
                # for idx, turn in enumerate(turns):
                #     if idx == 0:
                #         input_ids = tokenizer.tokenize(turn)
                #     else:
                #         # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
                #         input_ids.extend(tokenizer.tokenize(turn))

                # # input_ids
                # if ctrl_code:
                #     ctrl_code_list = ctrl_code.split(" [CTRL] ")
                #     for code in ctrl_code_list:
                #         # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
                #         input_ids.extend(tokenizer.tokenize(code + " ."))

                # put control code at the beginning
                input_ids = []
                if ctrl_code:
                    ctrl_code_list = ctrl_code.split(" [CTRL] ")
                    for code in ctrl_code_list:
                        input_ids.extend(tokenizer.tokenize("( " + code + " )"))

                turns = turns[-3:]
                for turn in turns:
                    if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
                        turn = turn + " ."
                    input_ids.extend(tokenizer.tokenize(turn))

                # output_ids
                outputs = ctrl_sent
                output_ids = tokenizer.tokenize(outputs)

                data_list.append({"input_ids": input_ids, "output_ids": output_ids})

            else:
                raise ValueError("Please input a correct train-module name! (either dialog or control)")

    return data_list


def data_shuffle(data, seed):
    # set random seed to make the shuffling reproducible
    np.random.seed(seed)
    np.random.shuffle(data)
    return data


class ControlDialogDataset(torch.utils.data.Dataset):

    def __init__(self, data, max_seq_len, sep_id, pad_id, eod_id):
        # need to deal with padding, label masking
        self.data = data
        self.max_seq_len = max_seq_len
        self.sep_id = sep_id
        self.pad_id = pad_id
        self.eod_id = eod_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_dict = self.data[idx]
        input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
        assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"

        # length_of_loss_mask == length_of_text - 1
        # text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
        text = input_ids + output_ids + [self.eod_id]
        loss_mask = [0]*(len(input_ids)-1) + [1]*(len(output_ids)+1)

        text_len = len(text)
        if text_len > self.max_seq_len+1:
            text = text[:self.max_seq_len+1]
            loss_mask = loss_mask[:self.max_seq_len]
        else:
            text += [self.pad_id] * (self.max_seq_len+1 - text_len)
            loss_mask += [0] * (self.max_seq_len+1 - text_len)

        return {"text": np.array(text, dtype=np.int64), "loss_mask": np.array(loss_mask, dtype=np.int64)}


def build_train_valid_test_datasets(data_folder, dataset_name, train_module, max_seq_len, seed):
    """Build train, valid, and test datasets."""
    dataname_dict = {"wizard_of_wikipedia": {"train": "train_entity_based_control.txt", "valid": "valid_random_split_entity_based_control.txt", "test": "test_random_split_entity_based_control.txt"}}

    train_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["train"])
    valid_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["valid"])
    test_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["test"])

    tokenizer = get_tokenizer()
    train_data_list = read_data(tokenizer, train_data_path, train_module)
    valid_data_list = read_data(tokenizer, valid_data_path, train_module)
    test_data_list = read_data(tokenizer, test_data_path, train_module)

    # shuffle the training data
    train_data_list = data_shuffle(train_data_list, seed)

    # build train, valid, and test datasets
    train_dataset = ControlDialogDataset(train_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
    valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
    test_dataset = ControlDialogDataset(test_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)

    return train_dataset, valid_dataset, test_dataset
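Aside (not part of the commit): the listing above is tasks/dialctrl/data.py and the next listing is tasks/dialctrl/finetune.py, going by the tasks.dialctrl.* imports below. A minimal sketch, using made-up token ids, of how a ControlDialogDataset item lines up with next-token training once process_batch in finetune.py shifts the sequence by one:

import numpy as np

# hypothetical ids: four context tokens, two response tokens, EOD = 0, PAD = 1
input_ids, output_ids, eod_id, pad_id = [11, 12, 13, 14], [21, 22], 0, 1
max_seq_len = 8

text = input_ids + output_ids + [eod_id]                               # length 7
loss_mask = [0] * (len(input_ids) - 1) + [1] * (len(output_ids) + 1)   # length 6 == len(text) - 1
text_len = len(text)
text += [pad_id] * (max_seq_len + 1 - text_len)                        # pad text to max_seq_len + 1
loss_mask += [0] * (max_seq_len + 1 - text_len)                        # pad mask to max_seq_len

# the finetuning step then shifts by one position:
tokens = np.array(text[:-1])   # model inputs
labels = np.array(text[1:])    # next-token targets
# loss_mask is 1 exactly where labels holds a response token or the EOD token,
# so the loss covers the response but not the dialog context or the padding.
assert len(tokens) == len(labels) == len(loss_mask) == max_seq_len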
"""Controllable Dialogue Finetuning"""
import torch
from functools import partial

from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model import GPTModel
from megatron.training import evaluate_and_print_results
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import finetune
from tasks.dialctrl.data import build_train_valid_test_datasets
from tasks.dialctrl.utils import get_ltor_attention_masks_and_position_ids


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process
    )
    return model


def train_valid_datasets_provider():
    """Build train, valid, and test datasets for dialog/control module"""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets for %s module ...' % args.train_module)
    train_ds, valid_ds, _ = build_train_valid_test_datasets(
        data_folder=args.data_folder,
        dataset_name=args.dataset_name,
        train_module=args.train_module,
        max_seq_len=args.max_seq_len,
        seed=args.seed)
    print_rank_0("> finished creating datasets for %s module ..." % args.train_module)

    args.eval_interval = len(train_ds) // args.global_batch_size
    print_rank_0(' > evaluation interval: %d' % args.eval_interval)

    return train_ds, valid_ds


def process_batch(batch):
    """Generate a batch"""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text', 'loss_mask']
    datatype = torch.int64

    data_b = mpu.broadcast_data(keys, batch, datatype)

    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()
    loss_mask = data_b['loss_mask'].float()

    # Get the attention_mask and position ids.
    attention_mask, position_ids = \
        get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)

    return tokens, labels, loss_mask, attention_mask, position_ids


def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}


def forward_step(batch, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    try:
        batch_ = next(batch)
    except BaseException:
        batch_ = batch

    tokens, labels, loss_mask, attention_mask, position_ids = process_batch(batch_)

    output_tensor = model(tokens, position_ids, attention_mask,
                          labels=labels)

    return output_tensor, partial(loss_func, loss_mask)


def main():
    finetune(train_valid_datasets_provider, model_provider,
             forward_step=forward_step)
import torch

from megatron import print_rank_0


def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
    """Build attention masks and position id for left to right model."""

    micro_batch_size, seq_length = data.size()

    # Attention mask
    attention_mask = torch.tril(torch.ones((micro_batch_size, seq_length, seq_length), device=data.device)).view(micro_batch_size, 1, seq_length, seq_length)

    # mask padded tokens
    for b in range(micro_batch_size):
        for idx in range(seq_length-1):
            if data[b, idx] == eod_token_id:
                # mask out the pad tokens that come after the eod token
                attention_mask[b, 0, idx+1:, :] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)

    # # reset attention mask and position ids
    # # Loop through the batches:
    # for b in range(micro_batch_size):
    #     # Find indices where EOD token is.
    #     eod_index = position_ids[b, data[b] == eod_token_id]
    #     eod_index = eod_index.clone()
    #     # Loop through EOD indices:
    #     prev_index = 0
    #     for j in range(eod_index.size()[0]):
    #         i = eod_index[j]
    #         # Mask attention loss.
    #         attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
    #         # Reset positions.
    #         position_ids[b, (i + 1):] -= (i + 1 - prev_index)
    #         prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, position_ids
\ No newline at end of file
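Aside (not part of the commit): the listing above is tasks/dialctrl/utils.py, going by the import in finetune.py. A small sketch, with made-up token ids and 0 standing in for the EOD id, of what get_ltor_attention_masks_and_position_ids returns for a toy batch:

import torch
from tasks.dialctrl.utils import get_ltor_attention_masks_and_position_ids

# one sequence of length 5: two real tokens, the EOD token (id 0), then padding
data = torch.tensor([[11, 12, 0, 1, 1]])
attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(data, eod_token_id=0)

print(attention_mask.shape)   # torch.Size([1, 1, 5, 5]); boolean, True = masked out
print(position_ids)           # tensor([[0, 1, 2, 3, 4]])
# Besides the usual causal (upper-triangular) masking, every row after the EOD
# position is fully masked, so the padded positions attend to nothing.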
@@ -84,6 +84,18 @@ def get_tasks_args(parser):
                        help='Av.rank validation: how many other negatives to'
                        ' take from each question pool')

+    # finetune for controllable dialogue
+    group.add_argument('--train-module', type=str, default="",
+                       help='either control module or dialogue model (control or dialog)')
+    group.add_argument('--data-folder', type=str, default="",
+                       help='data folder (path of the data folder)')
+    group.add_argument('--dataset-name', type=str, default="",
+                       help='dataset name (e.g., wizard_of_wikipedia)')
+    group.add_argument('--max-seq-len', type=int, default=1024,
+                       help='maximum sequence length')
+    group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
+                       help='additional special tokens')
+
     return parser
@@ -108,6 +120,8 @@ if __name__ == '__main__':
         from orqa.evaluate_orqa import main
     elif args.task in ['RET-FINETUNE-NQ']:
         from orqa.supervised.finetune import main
+    elif args.task == 'dialctrl':
+        from dialctrl.finetune import main
     else:
         raise NotImplementedError('Task {} is not implemented.'.format(
             args.task))
...