Commit 63e59949 authored by Anmol Gupta's avatar Anmol Gupta
Browse files

support for separate dataset files for train, valid and test

parent 981c3dfa
......@@ -838,7 +838,15 @@ def _add_data_args(parser):
help='Path to the training dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
'dataset2-path ... It is used with --split when a '
'single dataset used for all three: train, valid '
'and test. It is exclusive to the other '
'--*-data-path args')
group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
'`90,5,5` will use 90%% of data for training, 5%% for '
'validation and 5%% for test.')
group.add_argument('--train-data-path', nargs='*', default=None,
help='Path to the training dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
......@@ -854,11 +862,7 @@ def _add_data_args(parser):
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
'`90,5,5` will use 90%% of data for training, 5%% for '
'validation and 5%% for test.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file.')
group.add_argument('--merge-file', type=str, default=None,
......
......@@ -70,6 +70,8 @@ def get_datasets_weights_and_num_samples(data_prefix,
[int(math.ceil(val * weight * 1.005))
for val in train_valid_test_num_samples])
else:
# Used when separate dataset files are provided for train,
# valid and test
datasets_train_valid_test_num_samples = [
int(math.ceil(train_valid_test_num_samples * weight * 1.005))
for weight in weights]
......
......@@ -28,11 +28,11 @@ from megatron.data.dataset_utils import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
def build_train_valid_test_datasets(data_prefix, train_data_prefix,
valid_data_prefix, test_data_prefix,
data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup):
def build_train_valid_test_datasets(data_prefix, data_impl,
splits_string, train_valid_test_num_samples,
seq_length, seed, skip_warmup,
train_data_prefix=None, valid_data_prefix=None,
test_data_prefix=None,):
"""Build train, valid, and test datasets."""
if data_prefix:
......
......@@ -108,15 +108,15 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
'for GPT ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
train_data_prefix=args.train_data_path,
valid_data_prefix=args.valid_data_path,
test_data_prefix=args.test_data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup))
skip_warmup=(not args.mmap_warmup),
train_data_prefix=args.train_data_path,
valid_data_prefix=args.valid_data_path,
test_data_prefix=args.test_data_path,)
print_rank_0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment