Commit 981c3dfa authored by ANMOL GUPTA

support separate datasets for train, valid and test

parent d63c2541
@@ -839,6 +839,21 @@ def _add_data_args(parser):
                       '1) a single data path, 2) multiple datasets in the'
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
+    group.add_argument('--train-data-path', nargs='*', default=None,
+                       help='Path to the training dataset. Accepted format:'
+                       '1) a single data path, 2) multiple datasets in the'
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
+    group.add_argument('--valid-data-path', nargs='*', default=None,
+                       help='Path to the validation dataset. Accepted format:'
+                       '1) a single data path, 2) multiple datasets in the'
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
+    group.add_argument('--test-data-path', nargs='*', default=None,
+                       help='Path to the test dataset. Accepted format:'
+                       '1) a single data path, 2) multiple datasets in the'
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
...
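Note on the new options: because they use `nargs='*'` like the existing `--data-path`, each flag accepts either a single path or alternating weight/path pairs. The snippet below is a minimal, standalone illustration of that collection behavior, not Megatron's actual parser; the corpus names are placeholders.

```python
import argparse

# Stand-in parser mirroring only the three new options from the diff above.
parser = argparse.ArgumentParser()
parser.add_argument('--train-data-path', nargs='*', default=None)
parser.add_argument('--valid-data-path', nargs='*', default=None)
parser.add_argument('--test-data-path', nargs='*', default=None)

args = parser.parse_args([
    '--train-data-path', '0.7', 'corpusA_text_document',
                         '0.3', 'corpusB_text_document',
    '--valid-data-path', 'corpusA_valid_text_document',
])
# args.train_data_path == ['0.7', 'corpusA_text_document',
#                          '0.3', 'corpusB_text_document']
# args.valid_data_path == ['corpusA_valid_text_document']
# args.test_data_path is None, so no test dataset would be built.
```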
@@ -63,12 +63,16 @@ def get_datasets_weights_and_num_samples(data_prefix,
    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
    # not uniformly distribute the number of samples, we still have
    # samples left to feed to the network.
+    if isinstance(train_valid_test_num_samples, list):
        datasets_train_valid_test_num_samples = []
        for weight in weights:
            datasets_train_valid_test_num_samples.append(
                [int(math.ceil(val * weight * 1.005))
                 for val in train_valid_test_num_samples])
+    else:
+        datasets_train_valid_test_num_samples = [
+            int(math.ceil(train_valid_test_num_samples * weight * 1.005))
+            for weight in weights]

    return prefixes, weights, datasets_train_valid_test_num_samples
...
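The new else-branch handles the case where a single (scalar) sample count is passed per split, as happens when separate train/valid/test paths are used, while the existing branch keeps handling a `[train, valid, test]` list. A small arithmetic sketch of both branches, outside the diff, with made-up counts and weights:

```python
import math

weights = [0.7, 0.3]

# Scalar case (new branch): e.g. 100000 training samples spread across weights.
scalar_counts = [int(math.ceil(100000 * w * 1.005)) for w in weights]
# -> [70350, 30150]; the 1.005 factor over-requests ~0.5% as a safety margin.

# List case (existing branch): one [train, valid, test] triple per weight.
train_valid_test = [100000, 5000, 1000]
list_counts = [[int(math.ceil(v * w * 1.005)) for v in train_valid_test]
               for w in weights]
# -> [[70350, 3518, 704], [30150, 1508, 302]]
```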
@@ -28,11 +28,15 @@ from megatron.data.dataset_utils import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset


-def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+def build_train_valid_test_datasets(data_prefix, train_data_prefix,
+                                    valid_data_prefix, test_data_prefix,
+                                    data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup):
    """Build train, valid, and test datasets."""

+    if data_prefix:
+        print_rank_0("Single data path provided for train, valid & test")
        # Single dataset.
        if len(data_prefix) == 1:
            return _build_train_valid_test_datasets(data_prefix[0],
@@ -75,6 +79,82 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
        return (blending_train_dataset, blending_valid_dataset,
                blending_test_dataset)

+    else:
+        print_rank_0("Separate data paths provided for train, valid & test. Split string will be ignored.")
+
+        assert (train_data_prefix is not None)
+        train_dataset, valid_dataset, test_dataset = None, None, None
+
+        # Single dataset.
+        train_dataset = build_dataset("train", train_data_prefix, data_impl,
+                                      train_valid_test_num_samples[0], seq_length, seed,
+                                      skip_warmup)
+
+        if valid_data_prefix is not None:
+            valid_dataset = build_dataset("valid", valid_data_prefix, data_impl,
+                                          train_valid_test_num_samples[1], seq_length, seed,
+                                          False)
+
+        if test_data_prefix is not None:
+            test_dataset = build_dataset("test", test_data_prefix, data_impl,
+                                         train_valid_test_num_samples[2], seq_length, seed,
+                                         False)
+
+        return (train_dataset, valid_dataset, test_dataset)
+
+
+def build_dataset(dataset_name, data_prefix, data_impl, num_samples, seq_length, seed, skip_warmup):
+    dataset = None
+    if len(data_prefix) == 1:
+        dataset = _build_dataset(dataset_name,
+                                 data_prefix[0], data_impl,
+                                 num_samples, seq_length,
+                                 seed, skip_warmup)
+    else:
+        # Blending dataset.
+        # Parse the values.
+        output = get_datasets_weights_and_num_samples(data_prefix, num_samples)
+        prefixes, weights, dataset_num_samples = output
+
+        # Build individual datasets.
+        datasets = []
+        for i in range(len(prefixes)):
+            ds = _build_dataset(dataset_name, prefixes[i],
+                                data_impl, dataset_num_samples[i],
+                                seq_length, seed, skip_warmup)
+            if ds:
+                datasets.append(ds)
+
+        if datasets:
+            dataset = BlendableDataset(datasets, weights)
+
+    return dataset
+
+
+def _build_dataset(dataset_name, data_prefix, data_impl,
+                   num_samples, seq_length, seed, skip_warmup):
+    """
+    Build dataset. This method is called when individual
+    train, valid, and test datasets are provided.
+    """
+
+    # Indexed dataset.
+    indexed_dataset = get_indexed_dataset_(data_prefix,
+                                           data_impl,
+                                           skip_warmup)
+
+    total_num_of_documents = indexed_dataset.sizes.shape[0]
+
+    print_rank_0('    {}:'.format(dataset_name))
+    print_rank_0('     document indices in [0, {}) total of {} '
+                 'documents'.format(total_num_of_documents, total_num_of_documents))
+
+    documents = np.arange(start=0, stop=total_num_of_documents,
+                          step=1, dtype=np.int32)
+
+    dataset = GPTDataset(dataset_name, data_prefix,
+                         documents, indexed_dataset,
+                         num_samples, seq_length, seed)
+
+    return dataset


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
...
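For orientation, a hypothetical call into the new code path (paths and sample counts are placeholders, and a real run needs an initialized Megatron environment with the corresponding indexed dataset files on disk): when `data_prefix` is empty, each split is built from its own prefix list, `splits_string` is ignored, and any split whose prefix is omitted simply comes back as `None`.

```python
from megatron.data.gpt_dataset import build_train_valid_test_datasets

train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix=None,                          # empty -> use the separate-path branch
    train_data_prefix=["0.7", "corpusA_text_document",
                       "0.3", "corpusB_text_document"],
    valid_data_prefix=["corpusA_valid_text_document"],
    test_data_prefix=None,                     # optional; test_ds is returned as None
    data_impl="mmap",
    splits_string="969,30,1",                  # ignored in this branch
    train_valid_test_num_samples=[100000, 5000, 1000],
    seq_length=1024,
    seed=1234,
    skip_warmup=True)
```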
@@ -108,6 +108,9 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
                 'for GPT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
+        train_data_prefix=args.train_data_path,
+        valid_data_prefix=args.valid_data_path,
+        test_data_prefix=args.test_data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
...
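Since `build_train_valid_test_datasets` only consults the per-split prefixes when `args.data_path` is falsy, mixing `--data-path` with the new flags would be silently ignored. A hypothetical guard like the one below (not part of this commit) could be added in the provider to catch conflicting configurations early.

```python
def check_data_args(args):
    # Hypothetical sanity check; names mirror the argparse destinations above.
    separate = any(p is not None for p in
                   (args.train_data_path, args.valid_data_path, args.test_data_path))
    assert not (args.data_path and separate), \
        "--data-path cannot be combined with --train/--valid/--test-data-path"
    assert args.data_path or args.train_data_path, \
        "provide either --data-path or at least --train-data-path"
```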