Commit 4676493c authored by zihanl's avatar zihanl
Browse files

fix params

parent f71a61fc
......@@ -36,21 +36,16 @@ def train_valid_datasets_provider():
"""Build train, valid, and test datasets for dialog/control module"""
args = get_args()
print_rank_0('> building train, validation, and test datasets for %s module ...' % args.train_module)
print_rank_0('> building train, validation, and test datasets for %s module ...' % args.module)
train_ds, valid_ds = build_train_valid_datasets(
train_data_path=args.train_data_path,
valid_data_path=args.test_data_path,
train_module=args.train_module,
module=args.module,
max_seq_len=args.max_seq_len,
seed=args.seed,
last_turn=args.last_turn,
no_control_code=args.no_control_code,
add_separator=args.add_separator,
add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
remove_ctrl_sent=args.remove_ctrl_sent)
seed=args.seed)
print_rank_0("> finished creating datasets for %s module ..." % args.train_module)
print_rank_0("> finished creating datasets for %s module ..." % args.module)
print_rank_0('> Train size: %d' % len(train_ds))
print_rank_0('> Validation size: %d' % len(valid_ds))
......
......@@ -142,9 +142,9 @@ if __name__ == '__main__':
from orqa.supervised.finetune import main
elif args.task == 'knwl-dialo-prompt':
from knwl_dialo.prompt import main
elif args.task == ['knwl-dialo-finetune', 'knwl-dialo-gen']:
elif args.task in ['knwl-dialo-finetune', 'knwl-dialo-gen']:
from knwl_dialo.finetune import main
elif args.task in ['knwl-dialo-eval-ppl', 'knwl-dialo-eval-f1']:
elif args.task == 'knwl-dialo-eval-f1':
from knwl_dialo.evaluate import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment