Commit f86bb671 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

Checked that bert, gpt, and albert run.

parent d6485684
...@@ -278,7 +278,7 @@ def add_data_args(parser): ...@@ -278,7 +278,7 @@ def add_data_args(parser):
help='path(s) to the validation data.') help='path(s) to the validation data.')
group.add_argument('--test-data', nargs='*', default=None, group.add_argument('--test-data', nargs='*', default=None,
help='path(s) to the testing data.') help='path(s) to the testing data.')
group.add_argument('--data-path', type=str, default=None, group.add_argument('--data-path', nargs='+', default=None,
help='path to combined dataset to split') help='path to combined dataset to split')
group.add_argument('--split', default='1000,1,1', group.add_argument('--split', default='1000,1,1',
help='comma-separated list of proportions for training,' help='comma-separated list of proportions for training,'
......
...@@ -131,6 +131,8 @@ def make_loaders(args): ...@@ -131,6 +131,8 @@ def make_loaders(args):
if eval_seq_length is not None and eval_seq_length < 0: if eval_seq_length is not None and eval_seq_length < 0:
eval_seq_length = eval_seq_length * world_size eval_seq_length = eval_seq_length * world_size
split = get_split(args) split = get_split(args)
if args.data_path is not None:
args.train_data = args.data_path
data_set_args = { data_set_args = {
'path': args.train_data, 'path': args.train_data,
'seq_length': seq_length, 'seq_length': seq_length,
......
...@@ -57,7 +57,7 @@ def make_gpt2_dataloaders(args): ...@@ -57,7 +57,7 @@ def make_gpt2_dataloaders(args):
pin_memory=True) pin_memory=True)
train = make_data_loader_(args.train_data) train = make_data_loader_(args.train_data)
valid = make_data_loader_(args.val_data) valid = make_data_loader_(args.valid_data)
test = make_data_loader_(args.test_data) test = make_data_loader_(args.test_data)
args.do_train = False args.do_train = False
......
...@@ -143,9 +143,10 @@ def get_train_val_test_data(args): ...@@ -143,9 +143,10 @@ def get_train_val_test_data(args):
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1])) print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
print_rank_0(' test: {}'.format(train_val_test_num_samples[2])) print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
assert len(args.data_path) == 1
train_ds, valid_ds, test_ds = build_train_valid_test_datasets( train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
vocab_file=args.vocab, vocab_file=args.vocab,
data_prefix=args.data_path, data_prefix=args.data_path[0],
data_impl=args.data_impl, data_impl=args.data_impl,
splits_string=args.split, splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples, train_valid_test_num_samples=train_val_test_num_samples,
......
...@@ -113,9 +113,15 @@ def get_train_val_test_data(args): ...@@ -113,9 +113,15 @@ def get_train_val_test_data(args):
# Data loader only on rank 0 of each model parallel group. # Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0: if mpu.get_model_parallel_rank() == 0:
if args.data_loader == 'numpy': if args.data_loader == 'numpy':
assert len(args.train_data) == 1
args.train_data = args.train_data[0]
assert len(args.valid_data) == 1
args.valid_data = args.valid_data[0]
assert len(args.test_data) == 1
args.test_data = args.test_data[0]
(train_data, val_data, test_data), num_tokens, \ (train_data, val_data, test_data), num_tokens, \
eod_token = make_gpt2_dataloaders(args) eod_token = make_gpt2_dataloaders(args)
elif args.data_loader == 'raw' or args.data_loader == 'lazy' elif args.data_loader == 'raw' or args.data_loader == 'lazy':
data_config = configure_data() data_config = configure_data()
data_config.set_defaults(data_set_type='GPT2', transpose=False) data_config.set_defaults(data_set_type='GPT2', transpose=False)
(train_data, val_data, test_data), tokenizer = data_config.apply( (train_data, val_data, test_data), tokenizer = data_config.apply(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment