Commit 981c3dfa authored by ANMOL GUPTA

support separate datasets for train, valid and test

parent d63c2541
@@ -839,6 +839,21 @@ def _add_data_args(parser):
                        '1) a single data path, 2) multiple datasets in the'
                        'form: dataset1-weight dataset1-path dataset2-weight '
                        'dataset2-path ...')
+    group.add_argument('--train-data-path', nargs='*', default=None,
+                       help='Path to the training dataset. Accepted format:'
+                       '1) a single data path, 2) multiple datasets in the'
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
+    group.add_argument('--valid-data-path', nargs='*', default=None,
+                       help='Path to the validation dataset. Accepted format:'
+                       '1) a single data path, 2) multiple datasets in the'
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
+    group.add_argument('--test-data-path', nargs='*', default=None,
+                       help='Path to the test dataset. Accepted format:'
+                       '1) a single data path, 2) multiple datasets in the'
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
......
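The three new flags mirror the existing --data-path argument: with nargs='*', each accepts either a single dataset prefix or alternating weight/prefix pairs, and defaults to None when omitted. A minimal sketch of that parsing behaviour, outside of Megatron and with made-up dataset prefixes:

# Standalone illustration of how the new nargs='*' flags parse;
# the dataset prefixes below are placeholders, not real files.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train-data-path', nargs='*', default=None)
parser.add_argument('--valid-data-path', nargs='*', default=None)
parser.add_argument('--test-data-path', nargs='*', default=None)

# Single prefix per split; an omitted split simply stays None.
args = parser.parse_args(
    '--train-data-path my-corpus_train_text_document '
    '--valid-data-path my-corpus_valid_text_document'.split())
print(args.train_data_path)   # ['my-corpus_train_text_document']
print(args.test_data_path)    # None

# Weighted blend for one split: weight1 prefix1 weight2 prefix2 ...
args = parser.parse_args(
    '--train-data-path 0.7 corpus-a_text_document '
    '0.3 corpus-b_text_document'.split())
print(args.train_data_path)
# ['0.7', 'corpus-a_text_document', '0.3', 'corpus-b_text_document']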
@@ -63,12 +63,16 @@ def get_datasets_weights_and_num_samples(data_prefix,
     # Add 0.5% (the 1.005 factor) so in case the blending dataset does
     # not uniformly distribute the number of samples, we still have
     # samples left to feed to the network.
-    datasets_train_valid_test_num_samples = []
-    for weight in weights:
-        datasets_train_valid_test_num_samples.append(
-            [int(math.ceil(val * weight * 1.005))
-             for val in train_valid_test_num_samples])
+    if isinstance(train_valid_test_num_samples, list):
+        datasets_train_valid_test_num_samples = []
+        for weight in weights:
+            datasets_train_valid_test_num_samples.append(
+                [int(math.ceil(val * weight * 1.005))
+                 for val in train_valid_test_num_samples])
+    else:
+        datasets_train_valid_test_num_samples = [
+            int(math.ceil(train_valid_test_num_samples * weight * 1.005))
+            for weight in weights]

     return prefixes, weights, datasets_train_valid_test_num_samples
......
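To see what the new scalar branch produces, here is a standalone rerun of the oversampling arithmetic with hand-picked weights (in the real function the weights are derived from the prefix list); an illustration, not the Megatron function itself:

import math

weights = [0.7, 0.3]

# Existing path: a [train, valid, test] list of sample counts,
# expanded into one list per blended dataset.
train_valid_test_num_samples = [1000, 100, 10]
per_dataset = [[int(math.ceil(val * w * 1.005))
                for val in train_valid_test_num_samples]
               for w in weights]
print(per_dataset)   # [[704, 71, 8], [302, 31, 4]]

# New path: a single split's sample count, used when that split comes
# from its own --train/--valid/--test-data-path argument.
num_samples = 1000
per_dataset = [int(math.ceil(num_samples * w * 1.005)) for w in weights]
print(per_dataset)   # [704, 302]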
@@ -28,53 +28,133 @@ from megatron.data.dataset_utils import get_train_valid_test_split_
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset


-def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+def build_train_valid_test_datasets(data_prefix, train_data_prefix,
+                                    valid_data_prefix, test_data_prefix,
+                                    data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup):
     """Build train, valid, and test datasets."""

-    # Single dataset.
-    if len(data_prefix) == 1:
-        return _build_train_valid_test_datasets(data_prefix[0],
-                                                data_impl, splits_string,
-                                                train_valid_test_num_samples,
-                                                seq_length, seed, skip_warmup)
-
-    # Blending dataset.
-    # Parse the values.
-    output = get_datasets_weights_and_num_samples(data_prefix,
-                                                  train_valid_test_num_samples)
-    prefixes, weights, datasets_train_valid_test_num_samples = output
-
-    # Build individual datasets.
-    train_datasets = []
-    valid_datasets = []
-    test_datasets = []
-    for i in range(len(prefixes)):
-        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
-            prefixes[i], data_impl, splits_string,
-            datasets_train_valid_test_num_samples[i],
-            seq_length, seed, skip_warmup)
-        if train_ds:
-            train_datasets.append(train_ds)
-        if valid_ds:
-            valid_datasets.append(valid_ds)
-        if test_ds:
-            test_datasets.append(test_ds)
-
-    # Blend.
-    blending_train_dataset = None
-    if train_datasets:
-        blending_train_dataset = BlendableDataset(train_datasets, weights)
-    blending_valid_dataset = None
-    if valid_datasets:
-        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
-    blending_test_dataset = None
-    if test_datasets:
-        blending_test_dataset = BlendableDataset(test_datasets, weights)
-
-    return (blending_train_dataset, blending_valid_dataset,
-            blending_test_dataset)
+    if data_prefix:
+        print_rank_0("Single data path provided for train, valid & test")
+
+        # Single dataset.
+        if len(data_prefix) == 1:
+            return _build_train_valid_test_datasets(data_prefix[0],
+                                                    data_impl, splits_string,
+                                                    train_valid_test_num_samples,
+                                                    seq_length, seed, skip_warmup)
+
+        # Blending dataset.
+        # Parse the values.
+        output = get_datasets_weights_and_num_samples(data_prefix,
+                                                      train_valid_test_num_samples)
+        prefixes, weights, datasets_train_valid_test_num_samples = output
+
+        # Build individual datasets.
+        train_datasets = []
+        valid_datasets = []
+        test_datasets = []
+        for i in range(len(prefixes)):
+            train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
+                prefixes[i], data_impl, splits_string,
+                datasets_train_valid_test_num_samples[i],
+                seq_length, seed, skip_warmup)
+            if train_ds:
+                train_datasets.append(train_ds)
+            if valid_ds:
+                valid_datasets.append(valid_ds)
+            if test_ds:
+                test_datasets.append(test_ds)
+
+        # Blend.
+        blending_train_dataset = None
+        if train_datasets:
+            blending_train_dataset = BlendableDataset(train_datasets, weights)
+        blending_valid_dataset = None
+        if valid_datasets:
+            blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+        blending_test_dataset = None
+        if test_datasets:
+            blending_test_dataset = BlendableDataset(test_datasets, weights)
+
+        return (blending_train_dataset, blending_valid_dataset,
+                blending_test_dataset)
+
+    else:
+        print_rank_0("Separate data paths provided for train, valid & test. Split string will be ignored.")
+        assert (train_data_prefix is not None)
+        train_dataset, valid_dataset, test_dataset = None, None, None
+
+        # Single dataset.
+        train_dataset = build_dataset("train", train_data_prefix, data_impl,
+                                      train_valid_test_num_samples[0],
+                                      seq_length, seed, skip_warmup)
+
+        if valid_data_prefix is not None:
+            valid_dataset = build_dataset("valid", valid_data_prefix, data_impl,
+                                          train_valid_test_num_samples[1],
+                                          seq_length, seed, False)
+
+        if test_data_prefix is not None:
+            test_dataset = build_dataset("test", test_data_prefix, data_impl,
+                                         train_valid_test_num_samples[2],
+                                         seq_length, seed, False)
+
+        return (train_dataset, valid_dataset, test_dataset)
+
+
+def build_dataset(dataset_name, data_prefix, data_impl, num_samples, seq_length, seed, skip_warmup):
+    dataset = None
+    if len(data_prefix) == 1:
+        dataset = _build_dataset(dataset_name,
+                                 data_prefix[0], data_impl,
+                                 num_samples, seq_length,
+                                 seed, skip_warmup)
+    else:
+        # Blending dataset.
+        # Parse the values.
+        output = get_datasets_weights_and_num_samples(data_prefix, num_samples)
+        prefixes, weights, dataset_num_samples = output
+
+        # Build individual datasets.
+        datasets = []
+        for i in range(len(prefixes)):
+            ds = _build_dataset(dataset_name, prefixes[i],
+                                data_impl, dataset_num_samples[i],
+                                seq_length, seed, skip_warmup)
+            if ds:
+                datasets.append(ds)
+
+        if datasets:
+            dataset = BlendableDataset(datasets, weights)
+
+    return dataset
+
+
+def _build_dataset(dataset_name, data_prefix, data_impl,
+                   num_samples, seq_length, seed, skip_warmup):
+    """
+    Build dataset. This method is called when individual
+    train, valid, and test datasets are provided.
+    """
+
+    # Indexed dataset.
+    indexed_dataset = get_indexed_dataset_(data_prefix,
+                                           data_impl,
+                                           skip_warmup)
+
+    total_num_of_documents = indexed_dataset.sizes.shape[0]
+
+    print_rank_0('    {}:'.format(dataset_name))
+    print_rank_0('     document indices in [0, {}) total of {} '
+                 'documents'.format(total_num_of_documents, total_num_of_documents))
+
+    documents = np.arange(start=0, stop=total_num_of_documents,
+                          step=1, dtype=np.int32)
+
+    dataset = GPTDataset(dataset_name, data_prefix,
+                         documents, indexed_dataset,
+                         num_samples, seq_length, seed)
+
+    return dataset


 def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
......
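The control flow introduced above is easiest to see with the actual dataset machinery stripped away. The sketch below stubs out dataset construction entirely; the function and variable names are illustrative only, not Megatron APIs:

# Simplified, runnable mock of the new dispatch: a combined data_prefix
# keeps the old split-based behaviour, otherwise each split is built
# from its own prefix and --split is ignored.
def build_dataset_stub(name, prefix, num_samples):
    return f"{name}:{'+'.join(prefix)}[{num_samples}]"

def build_train_valid_test(data_prefix, train_p, valid_p, test_p, nums):
    if data_prefix:
        # Combined corpus: carved into train/valid/test as before.
        return tuple(build_dataset_stub(n, data_prefix, s)
                     for n, s in zip(("train", "valid", "test"), nums))
    # Separate corpora: train is required, valid/test are optional.
    assert train_p is not None
    train = build_dataset_stub("train", train_p, nums[0])
    valid = build_dataset_stub("valid", valid_p, nums[1]) if valid_p else None
    test = build_dataset_stub("test", test_p, nums[2]) if test_p else None
    return train, valid, test

print(build_train_valid_test(None, ["train_doc"], ["valid_doc"], None,
                             (1000, 100, 10)))
# ('train:train_doc[1000]', 'valid:valid_doc[100]', None)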
@@ -108,6 +108,9 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
                  'for GPT ...')
     train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
         data_prefix=args.data_path,
+        train_data_prefix=args.train_data_path,
+        valid_data_prefix=args.valid_data_path,
+        test_data_prefix=args.test_data_path,
         data_impl=args.data_impl,
         splits_string=args.split,
         train_valid_test_num_samples=train_val_test_num_samples,
......
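Since build_train_valid_test_datasets keys entirely off data_prefix being truthy, per-split paths passed alongside --data-path are silently ignored. A hedged sketch of an argument check one could add to surface that early; it is not part of this commit:

import argparse

def check_data_args(args):
    # Hypothetical guard, not in this commit: reject ambiguous combinations.
    separate = any(p is not None for p in (args.train_data_path,
                                           args.valid_data_path,
                                           args.test_data_path))
    if args.data_path and separate:
        raise ValueError('--data-path and the per-split *-data-path flags '
                         'are mutually exclusive; the per-split paths would '
                         'be silently ignored.')
    if not args.data_path and args.train_data_path is None:
        raise ValueError('Provide either --data-path or at least '
                         '--train-data-path.')

# Example: separate train path only, no combined --data-path; passes.
check_data_args(argparse.Namespace(data_path=None,
                                   train_data_path=['corpus_text_document'],
                                   valid_data_path=None,
                                   test_data_path=None))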