Commit fac6718a authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'blendable_dataset' into 'megatron_sampler'

Blendable dataset

See merge request ADLR/megatron-lm!178
parents cebd3b8b 1eda0a17
......@@ -136,14 +136,16 @@ def parse_args(extra_args_provider=None, defaults={},
def _print_args(args):
"""Print arguments."""
if args.rank == 0:
print('-------------------- arguments --------------------', flush=True)
print('------------------------ arguments ------------------------',
flush=True)
str_list = []
for arg in vars(args):
dots = '.' * (32 - len(arg))
dots = '.' * (48 - len(arg))
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
print(arg, flush=True)
print('---------------- end of arguments ----------------', flush=True)
print('-------------------- end of arguments ---------------------',
flush=True)
def _check_arg_is_not_none(args, arg):
......@@ -278,7 +280,7 @@ def _add_learning_rate_args(parser):
'and initial warmup, the learing rate at each '
'iteration would be different.')
group.add_argument('--lr-decay-style', type=str, default='linear',
choices=['constant', 'linear', 'cosine', 'exponential'],
choices=['constant', 'linear', 'cosine'],
help='Learning rate decay function.')
group.add_argument('--lr-decay-iters', type=int, default=None,
help='number of iterations to decay learning rate over,'
......@@ -400,8 +402,11 @@ def _add_validation_args(parser):
def _add_data_args(parser):
group = parser.add_argument_group(title='data and dataloader')
group.add_argument('--data-path', type=str, default=None,
help='Path to combined dataset to split.')
group.add_argument('--data-path', nargs='*', default=None,
help='Path to the training dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Blendable dataset."""
import time
import numpy as np
import torch
from megatron import print_rank_0
from megatron import mpu
class BlendableDataset(torch.utils.data.Dataset):
def __init__(self, datasets, weights):
self.datasets = datasets
num_datasets = len(datasets)
assert num_datasets == len(weights)
self.size = 0
for dataset in self.datasets:
self.size += len(dataset)
# Normalize weights.
weights = np.array(weights, dtype=np.float64)
sum_weights = np.sum(weights)
assert sum_weights > 0.0
weights /= sum_weights
# Build indecies.
start_time = time.time()
assert num_datasets < 255
self.dataset_index = np.zeros(self.size, dtype=np.uint8)
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
if torch.distributed.get_rank() == 0:
from megatron.data.dataset_utils import compile_helper
compile_helper()
# Simple barrier
tmp = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(tmp, group=mpu.get_data_parallel_group())
from megatron.data import helpers
helpers.build_blending_indices(self.dataset_index,
self.dataset_sample_index,
weights, num_datasets, self.size,
torch.distributed.get_rank() == 0)
print_rank_0('> elapsed time for building blendable dataset indices: '
'{:.2f} (sec)'.format(time.time() - start_time))
def __len__(self):
return self.size
def __getitem__(self, idx):
dataset_idx = self.dataset_index[idx]
sample_idx = self.dataset_sample_index[idx]
return self.datasets[dataset_idx][sample_idx]
......@@ -18,11 +18,13 @@
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.
import math
import time
import collections
import numpy as np
from megatron import get_args, print_rank_0
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
DSET_TYPE_STD = 'standard_bert'
......@@ -31,6 +33,38 @@ DSET_TYPE_ICT = 'ict'
DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
def get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples):
# The data prefix should be in the format of:
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
assert len(data_prefix) % 2 == 0
num_datasets = len(data_prefix) // 2
weights = [0]*num_datasets
prefixes = [0]*num_datasets
for i in range(num_datasets):
weights[i] = float(data_prefix[2*i])
prefixes[i] = (data_prefix[2*i+1]).strip()
# Normalize weights
weight_sum = 0.0
for weight in weights:
weight_sum += weight
assert weight_sum > 0.0
weights = [weight / weight_sum for weight in weights]
# Add 0.5% (the 1.005 factor) so in case the bleding dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
datasets_train_valid_test_num_samples = []
for weight in weights:
datasets_train_valid_test_num_samples.append(
[int(math.ceil(val * weight * 1.005))
for val in train_valid_test_num_samples])
return prefixes, weights, datasets_train_valid_test_num_samples
def compile_helper():
"""Compile helper function ar runtime. Make sure this
is invoked on a single process."""
......@@ -360,6 +394,46 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
short_seq_prob, seed, skip_warmup,
dataset_type='standard_bert'):
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(data_prefix[0],
data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed,
skip_warmup,
dataset_type=dataset_type)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
train_datasets = []
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, dataset_type=dataset_type)
# Blend.
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
dataset_type='standard_bert'):
if dataset_type not in DSET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
......
......@@ -22,6 +22,8 @@ import numpy as np
import torch
from megatron import mpu, print_rank_0
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from megatron.data.dataset_utils import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
......@@ -31,6 +33,46 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
seq_length, seed, skip_warmup):
"""Build train, valid, and test datasets."""
# Single dataset.
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(data_prefix[0],
data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
train_datasets = []
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
seq_length, seed, skip_warmup)
train_datasets.append(train_ds)
valid_datasets.append(valid_ds)
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup):
"""Build train, valid, and test datasets."""
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
......
......@@ -33,6 +33,69 @@ using namespace std;
const int32_t LONG_SENTENCE_LEN = 512;
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
py::array_t<int64_t>& dataset_sample_index,
const py::array_t<double>& weights,
const int32_t num_datasets,
const int64_t size, const bool verbose) {
/* Given multiple datasets and a weighting array, build samples
such that it follows those wieghts.*/
if (verbose) {
std::cout << "> building indices for blendable datasets ..." << std::endl;
}
// Get the pointer access without the checks.
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
auto weights_ptr = weights.unchecked<1>();
// Initialize buffer for number of samples used for each dataset.
int64_t current_samples[num_datasets];
for(int64_t i = 0; i < num_datasets; ++i) {
current_samples[i] = 0;
}
// For each sample:
for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
// Determine where the max error in sampling is happening.
auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
int64_t max_error_index = 0;
double max_error = weights_ptr[0] * sample_idx_double -
static_cast<double>(current_samples[0]);
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
double error = weights_ptr[dataset_idx] * sample_idx_double -
static_cast<double>(current_samples[dataset_idx]);
if (error > max_error) {
max_error = error;
max_error_index = dataset_idx;
}
}
// Populate the indices.
dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
// Update the total samples.
current_samples[max_error_index] += 1;
}
// print info
if (verbose) {
std::cout << " > sample ratios:" << std::endl;
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
auto ratio = static_cast<double>(current_samples[dataset_idx]) /
static_cast<double>(size);
std::cout << " dataset " << dataset_idx << ", input: " <<
weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
}
}
}
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& doc_idx_,
const int32_t seq_length,
......@@ -640,4 +703,5 @@ PYBIND11_MODULE(helpers, m) {
m.def("build_mapping", &build_mapping);
m.def("build_blocks_mapping", &build_blocks_mapping);
m.def("build_sample_idx", &build_sample_idx);
m.def("build_blending_indices", &build_blending_indices);
}
......@@ -6,7 +6,6 @@ import torch
from megatron import mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron.data.samplers import DistributedBatchSampler
from megatron import get_args, get_tokenizer, print_rank_0, mpu
......@@ -23,6 +22,8 @@ def get_one_epoch_dataloader(dataset, batch_size=None):
sampler = torch.utils.data.SequentialSampler(dataset)
# importantly, drop_last must be False to get all the data.
assert False, 'DistributedBatchSampler deprecated, change the implementation'
from megatron.data.samplers import DistributedBatchSampler
batch_sampler = DistributedBatchSampler(sampler,
batch_size=global_batch_size,
drop_last=False,
......
......@@ -19,77 +19,101 @@ import math
from megatron import print_rank_0
class AnnealingLR(object):
"""Anneals the learning rate."""
def __init__(self, optimizer, start_lr,
warmup_iter, total_iters,
decay_style, last_iter, min_lr=0.0,
def __init__(self, optimizer, max_lr, min_lr,
warmup_steps, decay_steps,
decay_style, num_steps,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
# Class values.
self.optimizer = optimizer
self.start_lr = start_lr
self.max_lr = float(max_lr)
self.min_lr = min_lr
self.warmup_iter = warmup_iter
self.num_iters = last_iter
self.end_iter = total_iters
assert self.end_iter > 0
assert self.min_lr >= 0.0
assert self.max_lr >= self.min_lr
self.warmup_steps = warmup_steps
self.num_steps = num_steps
self.decay_steps = decay_steps
assert self.decay_steps > 0
assert self.warmup_steps < self.decay_steps
self.decay_style = decay_style
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, 'both override and '\
'use-checkpoint are set.'
# Set the learning rate
self.step(self.num_iters)
self.step(step_num=self.num_steps)
print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
def get_lr(self):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter)
# Warmup.
if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
return float(self.start_lr) * num_iters_ / self.warmup_iter
# Use linear warmup for the initial part.
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
return self.max_lr * float(self.num_steps) / \
float(self.warmup_steps)
# If the learning rate is constant, just return the initial value.
if self.decay_style == 'constant':
return self.max_lr
# For any steps larger than `self.decay_steps`, use `self.min_lr`.
if self.num_steps > self.decay_steps:
return self.min_lr
# If we are done with the warmup period, use the decay style.
num_steps_ = self.num_steps - self.warmup_steps
decay_steps_ = self.decay_steps - self.warmup_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_lr = self.max_lr - self.min_lr
num_iters_ = num_iters_ - self.warmup_iter
if self.decay_style == 'linear':
lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter
coeff = (1.0 - decay_ratio)
elif self.decay_style == 'cosine':
lr = self.start_lr / 2.0 * (math.cos(
math.pi * num_iters_ / self.end_iter) + 1)
elif self.decay_style == 'exponential':
# exp(-0.693) = 1/2
lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter)
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
else:
lr = self.start_lr
return max(lr, self.min_lr)
raise Exception('{} decay style is not supported.'.format(
self.decay_style))
return self.min_lr + coeff * delta_lr
def step(self, step_num=None):
def step(self, increment=1, step_num=None):
"""Set lr for all parameters groups."""
if step_num is None:
step_num = self.num_iters + 1
self.num_iters = step_num
step_num = self.num_steps + increment
self.num_steps = step_num
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = new_lr
def state_dict(self):
state_dict = {
'start_lr': self.start_lr,
'warmup_iter': self.warmup_iter,
'num_iters': self.num_iters,
'max_lr': self.max_lr,
'warmup_steps': self.warmup_steps,
'num_steps': self.num_steps,
'decay_style': self.decay_style,
'end_iter': self.end_iter,
'decay_steps': self.decay_steps,
'min_lr': self.min_lr
}
return state_dict
def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and
setting them."""
......@@ -104,20 +128,39 @@ class AnnealingLR(object):
name))
return sd_value
def load_state_dict(self, sd):
self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
'learning rate')
if 'start_lr' in sd:
max_lr_ = sd['start_lr']
else:
max_lr_ = sd['max_lr']
self.max_lr = self._check_and_set(self.max_lr, max_lr_,
'learning rate')
self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
'minimum learning rate')
self.warmup_iter = self._check_and_set(self.warmup_iter,
sd['warmup_iter'],
'warmup iterations')
self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'],
'total number of iterations')
if 'warmup_iter' in sd:
warmup_steps_ = sd['warmup_iter']
else:
warmup_steps_ = sd['warmup_steps']
self.warmup_steps = self._check_and_set(self.warmup_steps,
warmup_steps_,
'warmup iterations')
if 'end_iter' in sd:
decay_steps_ = sd['end_iter']
else:
decay_steps_ = sd['decay_steps']
self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
'total number of iterations')
self.decay_style = self._check_and_set(self.decay_style,
sd['decay_style'],
'decay style')
self.num_iters = sd['num_iters']
self.step(self.num_iters)
if 'num_iters' in sd:
self.num_steps = sd['num_iters']
else:
self.num_steps = sd['num_steps']
self.step(step_num=self.num_steps)
......@@ -194,12 +194,12 @@ def get_learning_rate_scheduler(optimizer):
warmup_iter = args.warmup * num_iters
lr_scheduler = AnnealingLR(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iter,
total_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step,
max_lr=args.lr,
min_lr=args.min_lr,
warmup_steps=warmup_iter,
decay_steps=num_iters,
decay_style=args.lr_decay_style,
num_steps=init_step,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment