Commit 65290033 authored by mohammad's avatar mohammad

implemented blending datasets

parent 9a0808c9
......@@ -400,7 +400,7 @@ def _add_validation_args(parser):
def _add_data_args(parser):
    group = parser.add_argument_group(title='data and dataloader')

    group.add_argument('--data-path', type=str, default=None,
    group.add_argument('--data-path', nargs='*', default=None,
                       help='Path to combined dataset to split.')
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
......
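With nargs='*', --data-path now accepts either a single prefix (the old type=str definition, kept on the line above its replacement) or a flat list of alternating weight / prefix pairs; note the help string still describes only the single-dataset form. A hypothetical invocation with two weighted corpora (paths illustrative):

# Single dataset, unchanged behavior:
#     --data-path my-corpus_text_sentence
# Blended datasets, each weight followed by its prefix:
#     --data-path 0.3 corpus-a_text_sentence 0.7 corpus-b_text_sentence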
......@@ -18,11 +18,13 @@
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.
import math
import time
import collections
import numpy as np
from megatron import get_args, print_rank_0
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
DSET_TYPE_STD = 'standard_bert'
......@@ -31,6 +33,38 @@ DSET_TYPE_ICT = 'ict'
DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
def get_datasets_weights_and_num_samples(data_prefix,
                                         train_valid_test_num_samples):

    # The data prefix should be in the format of:
    #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
    assert len(data_prefix) % 2 == 0
    num_datasets = len(data_prefix) // 2
    weights = [0] * num_datasets
    prefixes = [0] * num_datasets
    for i in range(num_datasets):
        weights[i] = float(data_prefix[2 * i])
        prefixes[i] = (data_prefix[2 * i + 1]).strip()

    # Normalize weights.
    weight_sum = 0.0
    for weight in weights:
        weight_sum += weight
    assert weight_sum > 0.0
    weights = [weight / weight_sum for weight in weights]

    # Add 0.5% (the 1.005 factor) so that in case the blending dataset
    # does not uniformly distribute the number of samples, we still have
    # samples left to feed to the network.
    datasets_train_valid_test_num_samples = []
    for weight in weights:
        datasets_train_valid_test_num_samples.append(
            [int(math.ceil(val * weight * 1.005))
             for val in train_valid_test_num_samples])

    return prefixes, weights, datasets_train_valid_test_num_samples
def compile_helper():
    """Compile helper function at runtime. Make sure this
    is invoked on a single process."""
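As a quick illustration of get_datasets_weights_and_num_samples, assume two corpora weighted 0.3 and 0.7 and 1000/100/10 train/valid/test samples requested (all names and numbers hypothetical):

data_prefix = ['0.3', 'corpus-a_text_sentence', '0.7', 'corpus-b_text_sentence']
train_valid_test_num_samples = [1000, 100, 10]
prefixes, weights, per_dataset = get_datasets_weights_and_num_samples(
    data_prefix, train_valid_test_num_samples)
# prefixes    == ['corpus-a_text_sentence', 'corpus-b_text_sentence']
# weights     == [0.3, 0.7]
# per_dataset == [[302, 31, 4], [704, 71, 8]]   # ceil(n * weight * 1.005)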
......@@ -360,6 +394,46 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    short_seq_prob, seed, skip_warmup,
                                    dataset_type='standard_bert'):

    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                max_seq_length, masked_lm_prob,
                                                short_seq_prob, seed,
                                                skip_warmup,
                                                dataset_type=dataset_type)

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            max_seq_length, masked_lm_prob, short_seq_prob,
            seed, skip_warmup, dataset_type=dataset_type)
        train_datasets.append(train_ds)
        valid_datasets.append(valid_ds)
        test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
                                     dataset_type='standard_bert'):

    if dataset_type not in DSET_TYPES:
        raise ValueError("Invalid dataset_type: " + dataset_type)
......
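The imported megatron.data.blendable_dataset module is not part of the hunks shown here; only its interface, BlendableDataset(datasets, weights), is visible. A minimal pure-Python sketch of a dataset with that interface, assuming it precomputes a mapping from each blended index to a (dataset, sample) pair in proportion to the weights (illustrative only; the real module may differ):

import numpy as np
import torch


class BlendableDataset(torch.utils.data.Dataset):
    """Sketch: draw samples from several datasets according to weights."""

    def __init__(self, datasets, weights):
        assert len(datasets) == len(weights)
        self.datasets = datasets
        self.size = sum(len(d) for d in datasets)
        weights = np.array(weights, dtype=np.float64)
        weights /= weights.sum()
        # For every blended index, record which dataset to draw from and
        # which sample inside it, keeping each dataset's share of drawn
        # samples close to its weight.
        self.dataset_index = np.zeros(self.size, dtype=np.int64)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
        counts = np.zeros(len(datasets), dtype=np.int64)
        for idx in range(self.size):
            # Pick the dataset currently furthest behind its target share.
            d = int(np.argmax(weights * (idx + 1) - counts))
            self.dataset_index[idx] = d
            self.dataset_sample_index[idx] = counts[d] % len(datasets[d])
            counts[d] += 1

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        d = self.dataset_index[idx]
        return self.datasets[d][self.dataset_sample_index[idx]]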
......@@ -22,6 +22,8 @@ import numpy as np
import torch
from megatron import mpu, print_rank_0
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from megatron.data.dataset_utils import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
......@@ -31,6 +33,46 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    seq_length, seed, skip_warmup):
    """Build train, valid, and test datasets."""

    # Single dataset.
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(data_prefix[0],
                                                data_impl, splits_string,
                                                train_valid_test_num_samples,
                                                seq_length, seed, skip_warmup)

    # Blending dataset.
    # Parse the values.
    output = get_datasets_weights_and_num_samples(data_prefix,
                                                  train_valid_test_num_samples)
    prefixes, weights, datasets_train_valid_test_num_samples = output

    # Build individual datasets.
    train_datasets = []
    valid_datasets = []
    test_datasets = []
    for i in range(len(prefixes)):
        train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
            prefixes[i], data_impl, splits_string,
            datasets_train_valid_test_num_samples[i],
            seq_length, seed, skip_warmup)
        train_datasets.append(train_ds)
        valid_datasets.append(valid_ds)
        test_datasets.append(test_ds)

    # Blend.
    blending_train_dataset = BlendableDataset(train_datasets, weights)
    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
    blending_test_dataset = BlendableDataset(test_datasets, weights)

    return (blending_train_dataset, blending_valid_dataset,
            blending_test_dataset)


def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup):
    """Build train, valid, and test datasets."""

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
......
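The GPT-2 builder mirrors the BERT path above; assuming --data-path was parsed as in the earlier example, a blended call could look like this (all values illustrative):

train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix=['0.3', 'corpus-a_text_document',
                 '0.7', 'corpus-b_text_document'],
    data_impl='mmap',
    splits_string='969, 30, 1',
    train_valid_test_num_samples=[1000, 100, 10],
    seq_length=1024,
    seed=1234,
    skip_warmup=True)
# With more than one prefix, each returned object is a BlendableDataset
# drawing roughly 30% / 70% of its samples from the two corpora.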
......@@ -6,7 +6,6 @@ import torch
from megatron import mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron.data.samplers import DistributedBatchSampler
from megatron import get_args, get_tokenizer, print_rank_0, mpu
......@@ -23,6 +22,8 @@ def get_one_epoch_dataloader(dataset, batch_size=None):
    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    assert False, 'DistributedBatchSampler deprecated, change the implementation'
    from megatron.data.samplers import DistributedBatchSampler
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
......