Commit 63e59949 authored by Anmol Gupta

Support for separate dataset files for train, valid, and test

parent 981c3dfa
@@ -838,7 +838,15 @@ def _add_data_args(parser):
                        help='Path to the training dataset. Accepted format:'
                        '1) a single data path, 2) multiple datasets in the'
                        'form: dataset1-weight dataset1-path dataset2-weight '
-                       'dataset2-path ...')
+                       'dataset2-path ... It is used with --split when a '
+                       'single dataset is used for all three: train, valid '
+                       'and test. It is mutually exclusive with the other '
+                       '--*-data-path args.')
+    group.add_argument('--split', type=str, default='969, 30, 1',
+                       help='Comma-separated list of proportions for training,'
+                       ' validation, and test split. For example the split '
+                       '`90,5,5` will use 90%% of data for training, 5%% for '
+                       'validation and 5%% for test.')
     group.add_argument('--train-data-path', nargs='*', default=None,
                        help='Path to the training dataset. Accepted format:'
                        '1) a single data path, 2) multiple datasets in the'
@@ -854,11 +862,7 @@ def _add_data_args(parser):
                        '1) a single data path, 2) multiple datasets in the'
                        'form: dataset1-weight dataset1-path dataset2-weight '
                        'dataset2-path ...')
-    group.add_argument('--split', type=str, default='969, 30, 1',
-                       help='Comma-separated list of proportions for training,'
-                       ' validation, and test split. For example the split '
-                       '`90,5,5` will use 90%% of data for training, 5%% for '
-                       'validation and 5%% for test.')
     group.add_argument('--vocab-file', type=str, default=None,
                        help='Path to the vocab file.')
     group.add_argument('--merge-file', type=str, default=None,
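Taken together, the two argument hunks above establish two mutually exclusive modes: a single --data-path carved up by --split, or separate --train-data-path / --valid-data-path / --test-data-path corpora. As a point of reference, here is a minimal sketch, not the repository's code, of how the alternating "weight path weight path ..." form accepted by these flags can be unpacked; the helper name parse_weighted_prefixes is made up for illustration:

def parse_weighted_prefixes(data_prefix):
    """Split [w1, p1, w2, p2, ...] into normalized weights and paths."""
    assert len(data_prefix) % 2 == 0, \
        'expected alternating weight/path pairs'
    weights = [float(w) for w in data_prefix[0::2]]
    paths = list(data_prefix[1::2])
    total = sum(weights)
    # Normalize so the weights sum to 1.
    weights = [w / total for w in weights]
    return weights, paths

# Example: two corpora mixed 90/10 in the single --data-path mode.
weights, paths = parse_weighted_prefixes(
    ['0.9', 'corpus_a_text_document', '0.1', 'corpus_b_text_document'])
print(weights)  # [0.9, 0.1]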
@@ -70,6 +70,8 @@ def get_datasets_weights_and_num_samples(data_prefix,
                 [int(math.ceil(val * weight * 1.005))
                  for val in train_valid_test_num_samples])
     else:
+        # Used when separate dataset files are provided for train,
+        # valid and test.
         datasets_train_valid_test_num_samples = [
             int(math.ceil(train_valid_test_num_samples * weight * 1.005))
             for weight in weights]
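To make the two branches of get_datasets_weights_and_num_samples concrete, here is a worked example; the weights and sample counts are illustrative assumptions, not values from the commit. The 1.005 factor pads each request by 0.5% so that, after ceiling-rounding, no dataset comes up short of samples:

import math

weights = [0.9, 0.1]  # normalized dataset weights (illustrative)

# Single --data-path mode: one sample count per split (train, valid,
# test), so each dataset gets a per-split list of counts.
train_valid_test_num_samples = [1000, 100, 10]
per_dataset = [
    [int(math.ceil(val * weight * 1.005))
     for val in train_valid_test_num_samples]
    for weight in weights]
print(per_dataset)  # [[905, 91, 10], [101, 11, 2]]

# Separate-files mode (the new else branch): the function is called
# once per split, so train_valid_test_num_samples is a single integer.
num_samples = 1000
per_dataset = [int(math.ceil(num_samples * weight * 1.005))
               for weight in weights]
print(per_dataset)  # [905, 101]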
@@ -28,11 +28,11 @@ from megatron.data.dataset_utils import get_train_valid_test_split_
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset


-def build_train_valid_test_datasets(data_prefix, train_data_prefix,
-                                    valid_data_prefix, test_data_prefix,
-                                    data_impl, splits_string,
-                                    train_valid_test_num_samples,
-                                    seq_length, seed, skip_warmup):
+def build_train_valid_test_datasets(data_prefix, data_impl,
+                                    splits_string, train_valid_test_num_samples,
+                                    seq_length, seed, skip_warmup,
+                                    train_data_prefix=None, valid_data_prefix=None,
+                                    test_data_prefix=None,):
     """Build train, valid, and test datasets."""

     if data_prefix:
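The hunk shows only the signature and the opening if data_prefix:. The sketch below is a guess at the overall dispatch shape, not the function's real body; _build_split_datasets is a stand-in helper. It illustrates why the per-split prefixes were moved to trailing optional keywords: legacy callers that pass only data_prefix keep working unchanged.

def _build_split_datasets(prefix, num_samples):
    # Stand-in for the real indexed-dataset construction.
    return f'<dataset from {prefix}, {num_samples} samples>'

def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    seq_length, seed, skip_warmup,
                                    train_data_prefix=None,
                                    valid_data_prefix=None,
                                    test_data_prefix=None):
    if data_prefix:
        # Legacy path: one corpus, carved into train/valid/test by the
        # proportions in splits_string (e.g. '969,30,1').
        return tuple(_build_split_datasets(data_prefix, n)
                     for n in train_valid_test_num_samples)
    # New path: each split is built from its own prefix; a missing
    # prefix simply yields no dataset for that split.
    prefixes = (train_data_prefix, valid_data_prefix, test_data_prefix)
    return tuple(_build_split_datasets(p, n) if p else None
                 for p, n in zip(prefixes, train_valid_test_num_samples))

# Legacy call: positional arguments only, exactly as before the commit.
print(build_train_valid_test_datasets('corpus_text_document', 'mmap',
                                      '969,30,1', [1000, 100, 10],
                                      1024, 1234, True))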
@@ -108,15 +108,15 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
         'for GPT ...')
     train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
         data_prefix=args.data_path,
-        train_data_prefix=args.train_data_path,
-        valid_data_prefix=args.valid_data_path,
-        test_data_prefix=args.test_data_path,
         data_impl=args.data_impl,
         splits_string=args.split,
         train_valid_test_num_samples=train_val_test_num_samples,
         seq_length=args.seq_length,
         seed=args.seed,
-        skip_warmup=(not args.mmap_warmup))
+        skip_warmup=(not args.mmap_warmup),
+        train_data_prefix=args.train_data_path,
+        valid_data_prefix=args.valid_data_path,
+        test_data_prefix=args.test_data_path,)
     print_rank_0("> finished creating GPT datasets ...")
     return train_ds, valid_ds, test_ds
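Since the new --data-path help text declares it mutually exclusive with the --*-data-path arguments, a call site like the one above presumably expects exactly one of the two modes to be configured. A minimal validation sketch follows; this check is an assumption, as the diff does not show where such enforcement lives:

from argparse import Namespace

def validate_data_args(args):
    separate = [args.train_data_path, args.valid_data_path,
                args.test_data_path]
    if args.data_path is not None and any(p is not None for p in separate):
        raise ValueError('--data-path is exclusive with '
                         '--train/--valid/--test-data-path')
    if args.data_path is None and all(p is None for p in separate):
        raise ValueError('provide either --data-path or at least one of '
                         'the per-split --*-data-path arguments')

# Passes: separate-files mode with only a training corpus configured.
validate_data_args(Namespace(data_path=None,
                             train_data_path=['train_text_document'],
                             valid_data_path=None,
                             test_data_path=None))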