Commit 0601702a authored by Mohammad Shoeybi

zero worker seems to be working

parent f6a6811f
@@ -8,6 +8,7 @@ import numpy as np
import torch
from torch.utils.data import Dataset
from megatron import mpu
from megatron.data import helpers
from megatron.data import FullBertTokenizer
from megatron.data.dataset_utils import build_training_sample
@@ -15,22 +16,97 @@ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.utils import print_rank_0
def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
splits_string, train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup):
# Tokenizer is the same
tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
tokenizer.vocab_size()))
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
# Get start and end indices of train/valid/test into doc-idx
# Note that doc-idx is designed to be num-docs + 1 so we can
# easily iterate over it.
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
print_rank_0(' sentence indices in [{}, {}) total of {} '
'sentences'.format(start_index, end_index,
end_index - start_index))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
doc_idx_ptr = indexed_dataset.get_doc_idx()
# Slice the doc-idx
start_index = splits[index]
# Add +1 so we can index into the dataset to get the upper bound.
end_index = splits[index + 1] + 1
# New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly.
dataset = AlbertDataset(
name=name,
indexed_dataset=indexed_dataset,
tokenizer=tokenizer,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
masked_lm_prob=masked_lm_prob,
max_seq_length=max_seq_length,
short_seq_prob=short_seq_prob,
seed=seed)
# Set the original pointer back so the full indexed dataset is restored.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
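The split trick in build_dataset works because doc_idx stores the first sentence index of every document plus one trailing entry, so doc_idx[d + 1] is always a valid upper bound. A minimal sketch with toy numbers (not taken from the repository) shows why the slice needs the extra +1:

import numpy as np

# Toy doc_idx for four documents holding 2, 3, 1 and 2 sentences:
# document d spans sentence indices [doc_idx[d], doc_idx[d + 1]).
doc_idx = np.array([0, 2, 5, 6, 8], dtype=np.int64)

# Suppose the validation split covers documents [1, 3).
start_doc, end_doc = 1, 3

# Slice with end_doc + 1 so the view keeps the upper bound of its
# last document (doc_idx[3] == 6 in this example).
valid_view = doc_idx[start_doc:end_doc + 1]   # array([2, 5, 6])

# Within the view, document k still spans [valid_view[k], valid_view[k + 1]).
for k in range(len(valid_view) - 1):
    print('doc', k, 'sentences', valid_view[k], 'to', valid_view[k + 1])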
class AlbertDataset(Dataset):
def __init__(self, vocab_file, data_prefix, data_impl, skip_warmup,
num_epochs, max_num_samples, masked_lm_prob, max_seq_length,
short_seq_prob, seed):
def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed):
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length
self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
# Indexed dataset.
self.indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
# Tokenizer and dataset.
self.tokenizer = tokenizer
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
@@ -39,7 +115,8 @@ class AlbertDataset(Dataset):
max_num_samples,
self.max_seq_length,
short_seq_prob,
self.seed)
self.seed,
self.name)
# Vocab stuff.
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
@@ -48,7 +125,6 @@ class AlbertDataset(Dataset):
self.sep_id = self.tokenizer.vocab['[SEP]']
self.mask_id = self.tokenizer.vocab['[MASK]']
self.pad_id = self.tokenizer.vocab['[PAD]']
exit()
def num_tokens(self):
@@ -68,9 +144,11 @@ class AlbertDataset(Dataset):
sample = []
for index in range(start_index, end_index):
sample.append(self.indexed_dataset[index])
'''
for s in sample:
if len(s) > 1000:
print(self.tokenizer.convert_ids_to_tokens(s))
'''
return build_training_sample(sample, seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
@@ -80,25 +158,63 @@ class AlbertDataset(Dataset):
self.masked_lm_prob, rng)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
print_rank_0(' > building dataset index ...')
start_time = time.time()
print_rank_0("> Reading dataset index ...")
indexed_dataset = make_indexed_dataset(data_prefix,
data_impl,
skip_warmup)
print_rank_0("> Finished creating indexed dataset in {:4f} "
"seconds".format(time.time() - start_time))
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
print_rank_0(' > indexed dataset stats:')
print_rank_0(' number of documents: {}'.format(
indexed_dataset.doc_idx.shape[0] - 1))
print_rank_0(' number of sentences: {}'.format(
indexed_dataset.sizes.shape[0]))
return indexed_dataset
def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""
splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split/splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] +
int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index
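As a quick sanity check on the function above (inputs are hypothetical): the weights are normalized, turned into cumulative document counts, and the final boundary is adjusted so it lands exactly on size.

# Hypothetical inputs against the function defined directly above.
assert get_train_valid_test_split_('949,50,1', 1000) == [0, 949, 999, 1000]
# i.e. train uses documents [0, 949), validation [949, 999), test [999, 1000).
# Missing weights default to 0., so '90,10' leaves the test split empty.
assert get_train_valid_test_split_('90,10', 1000) == [0, 900, 1000, 1000]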
def get_samples_mapping_(indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
max_seq_length,
short_seq_prob,
seed):
seed,
name):
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
@@ -109,9 +225,11 @@ def get_samples_mapping_(indexed_dataset,
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_indexmap'
indexmap_filename += '_{}ep'.format(num_epochs)
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}_indexmap'.format(name)
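# Only append the epoch / sample-count suffixes when the caller passed
# explicit values; the int32/int64 sentinels checked below are presumably
# the defaults, so auto-sized runs keep reusing one cached mapping file.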
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
indexmap_filename += '_{}s'.format(seed)
@@ -120,8 +238,9 @@ def get_samples_mapping_(indexed_dataset,
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print('WARNING: could not find index map file {}, building '
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert indexed_dataset.doc_idx.dtype == np.int64
assert indexed_dataset.sizes.dtype == np.int32
@@ -129,6 +248,8 @@ def get_samples_mapping_(indexed_dataset,
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
samples_mapping = helpers.build_mapping(
indexed_dataset.doc_idx,
indexed_dataset.sizes,
@@ -138,21 +259,30 @@ def get_samples_mapping_(indexed_dataset,
short_seq_prob,
seed,
verbose)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0('> elapsed time to build and save samples mapping '
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
torch.distributed.barrier()
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0('> loading indexed mapping from {}'.format(
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
print_rank_0(' total number of samples: {}'.format(
samples_mapping.shape[0]))
return samples_mapping
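The all_reduce near the end of get_samples_mapping_ stands in for a barrier: with the NCCL backend, torch.distributed.barrier() assumes device_index == rank, which breaks under model parallelism, so every rank contributes a 1 and the reduction cannot finish until all of them arrive. A generic sketch of the same pattern (names are mine, not Megatron's):

import torch
import torch.distributed as dist

def pseudo_barrier(group=None):
    # Each rank contributes a 1; the collective only completes once every
    # rank in `group` has reached this call, which is exactly a barrier.
    counts = torch.cuda.LongTensor([1])
    dist.all_reduce(counts, group=group)
    # Sanity check: the sum equals the number of participating ranks.
    assert counts[0].item() == dist.get_world_size(group=group)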
@@ -39,12 +39,6 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
and sequence-length is the target sequence length.
*/
if (verbose) {
cout << " > using " << docs_.shape(0) - 1 <<
" documents with " << sizes_.shape(0) << " sentences ..." <<
endl << std::flush;
}
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
@@ -52,16 +46,36 @@
assert(short_seq_prob <= 1.0);
assert(seed > 0);
// For efficiency, convert probability to ratio. Note: rand() generates int.
const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
if (docs[docs.shape(0) - 1] != sizes.shape(0)) {
cout << "document values is not consistent with length of sizes: " <<
docs[docs.shape(0) - 1] << " != " << sizes.shape(0) << endl;
throw std::length_error("docs and sizes");
// For efficiency, convert probability to ratio. Note: rand() generates int.
const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
if (verbose) {
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl << std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 <<
endl << std::flush;
cout << " sentences range: [" << sent_start_index <<
", " << sent_end_index << ")" << endl << std::flush;
cout << " total number of sentences: " << num_sentences <<
endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " maximum number of samples: " << max_num_samples <<
endl << std::flush;
cout << " maximum sequence length: " << max_seq_length <<
endl << std::flush;
cout << " short sequence probability: " << short_seq_prob <<
endl << std::flush;
cout << " short sequence ration (1/prob): " << short_seq_ratio <<
endl << std::flush;
cout << " seed: " << seed << endl <<
std::flush;
}
// Mapping and its length (1D).
@@ -90,7 +104,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
if (map_index >= max_num_samples) {
if (verbose && (!second)) {
cout << " > reached " << max_num_samples << " samples after "
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl << std::flush;
}
break;
@@ -181,11 +195,11 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
if (!second) {
if (verbose) {
cout << " > number of empty documents: " << empty_docs <<
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " > number of documents with one sentence: " <<
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " > will create mapping for " << map_index <<
cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush;
}
assert(maps == NULL);
@@ -210,10 +224,6 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
swap(maps[i0 + 2], maps[j0 + 2]);
}
if (verbose) {
cout << "> done building the mapping." << endl;
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_) {
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
@@ -239,34 +249,20 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
const int seed,
const bool verbose) {
if (verbose) {
cout << "> building sample map using: " << endl << std::flush;
cout << " number of epochs: " << num_epochs << endl
<< std::flush;
cout << " maximum number of samples: " << max_num_samples << endl
<< std::flush;
cout << " maximum sequence length: " << max_seq_length << endl
<< std::flush;
cout << " short sequence probability: " << short_seq_prob << endl
<< std::flush;
cout << " seed: " << seed << endl
<< std::flush;
}
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
cout << " > using uint64 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
cout << " using uint64 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose);
} else {
if (verbose) {
cout << " > using uint32 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose);
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose);
}
}
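Two details of the C++ mapper above are worth restating: the short-sequence probability is folded into an integer ratio so a plain modulo comparison can replace a floating-point draw, and the mapping's element type is picked by whether the sentence count fits in 32 bits. A rough Python mirror of both ideas (illustrative only, not bound to the extension):

import random
import numpy as np

def choose_map_dtype(num_sentences):
    # Mirror of the uint32/uint64 dispatch above: take the smaller type
    # whenever every sentence index fits in 32 bits.
    return np.uint32 if num_sentences <= np.iinfo(np.uint32).max else np.uint64

def is_short_seq(short_seq_prob, rng):
    # prob -> ratio, e.g. 0.1 -> 10: the event fires for one value out of
    # `ratio` equally likely integers, matching round(1.0 / short_seq_prob).
    ratio = int(round(1.0 / short_seq_prob))
    return rng.randint(0, ratio - 1) == 0

rng = random.Random(1234)
hits = sum(is_short_seq(0.1, rng) for _ in range(10000))
print(hits / 10000.0)                  # roughly 0.1
print(choose_map_dtype(3 * 10 ** 9))   # numpy.uint64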
@@ -391,17 +391,17 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
offset = stream.tell()
if not skip_warmup:
print_rank_0("> Warming up index mmap file...")
print_rank_0(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print_rank_0("> Reading sizes...")
print_rank_0(" reading sizes...")
self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
print_rank_0("> Reading pointers...")
print_rank_0(" reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
offset=offset + self._sizes.nbytes)
print_rank_0("> Reading document index...")
print_rank_0(" reading document index...")
self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
def __del__(self):
@@ -447,13 +447,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
print_rank_0("> Warming up data mmap file...")
print_rank_0(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print_rank_0("> Creating numpy buffer of mmap...")
print_rank_0(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
print_rank_0("> Creating memory view of numpy buffer...")
print_rank_0(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print_rank_0("> Done")
def __del__(self):
self._bin_buffer_mmap._mmap.close()
@@ -470,7 +469,6 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
if self._index.dtype != np.int64:
np_array = np_array.astype(np.int64)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
@@ -492,6 +490,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def doc_idx(self):
return self._index.doc_idx
def get_doc_idx(self):
return self._index._doc_idx
def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@@ -13,43 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataset to split one large one into multiple smaller datasets"""
import torch
import numpy as np
def should_split(split):
"""
given split proportions checks if should split
Examples:
>>> should_split([10,0,0])
False
>>> should_split([1,.1,.2])
True
"""
return max(split)/sum(split) != 1.
def get_train_valid_test_split(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""
def get_split(args):
"""
Get dataset splits from comma separated string list
"""
splits = []
if args.split.find(',') != -1:
splits = [float(s) for s in args.split.split(',')]
elif args.split.find('/') != -1:
splits = [float(s) for s in args.split.split('/')]
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(args.split)]
split_total = sum(splits)
if split_total < 1.:
splits.append(1-split_total)
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
if args.valid_data is not None:
splits[1] = 0.
if args.test_data is not None:
splits[2] = 0.
final_sum = sum(splits)
return [s/final_sum for s in splits]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split/splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] +
int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
return splits_index
class SplitDataset(torch.utils.data.Dataset):
"""
@@ -13,21 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain BERT"""
"""Pretrain ALBERT"""
import torch
import torch.nn.functional as F
from configure_data import configure_data
from megatron import mpu
from megatron.model import BertModel
from megatron.utils import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import run
from megatron.data import AlbertDataset, split_dataset
from megatron.data.albert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler
def model_provider(args):
"""Build the model."""
@@ -109,94 +109,98 @@ def forward_step(data_iterator, model, args, timers):
def get_train_val_test_data(args):
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
(train_data, val_data, test_data) = (None, None, None)
(train_data, valid_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
if args.data_loader == None:
print_rank_0('> building train, validation, and test datasets '
'for ALBERT ...')
if args.data_loader is None:
args.data_loader = 'binary'
if args.data_loader == 'binary':
if not args.max_num_samples:
args.max_num_samples = (args.train_iters + 2 * args.eval_iters) * args.batch_size
if not args.data_path:
print("Albert currently only supports a unified dataset specified with --data-path")
exit(1)
print_rank_0("Creating AlbertDataset...")
full_data = AlbertDataset(
vocab_file=args.vocab,
data_prefix=args.data_path,
data_impl=args.data_impl,
skip_warmup=args.skip_mmap_warmup,
num_epochs=args.data_epochs,
max_num_samples=args.max_num_samples,
masked_lm_prob=args.mask_prob,
max_seq_length=args.seq_length,
short_seq_prob=args.short_seq_prob,
seed=args.seed)
print_rank_0("Finished creating AlbertDataset...")
split = split_dataset.get_split(args)
if split_dataset.should_split(split):
train_ds, val_ds, test_ds = split_dataset.split_ds(full_data, split, args.shuffle)
else:
train_ds = full_data
num_tokens = train_ds.num_tokens()
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
global_batch_size = args.batch_size * world_size
num_workers = args.num_workers
def make_data_loader_(dataset):
if not dataset:
return None
# Use a simple sampler with distributed batch sampler.
sampler = torch.utils.data.SequentialSampler(dataset)
batch_sampler = DistributedBatchSampler(
sampler=sampler,
batch_size=global_batch_size,
drop_last=True,
rank=rank,
world_size=world_size)
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
train_data = make_data_loader_(train_ds)
valid_data = make_data_loader_(val_ds)
test_data = make_data_loader_(test_ds)
do_train = train_data is not None and args.train_iters > 0
do_valid = valid_data is not None and args.eval_iters > 0
do_test = test_data is not None and args.eval_iters > 0
# Need to broadcast num_tokens and num_type_tokens.
token_counts = torch.cuda.LongTensor([num_tokens,
2, # hard coded num_type_tokens for now
int(do_train),
int(do_valid),
int(do_test)])
else:
print("Unsupported data loader for BERT.")
if args.data_loader != 'binary':
print('Unsupported {} data loader for ALBERT.'.format(
args.data_loader))
exit(1)
if not args.data_path:
print('ALBERT only supports a unified dataset specified '
'with --data-path')
exit(1)
data_parallel_size = mpu.get_data_parallel_world_size()
data_parallel_rank = mpu.get_data_parallel_rank()
global_batch_size = args.batch_size * data_parallel_size
# Number of train/valid/test samples.
train_iters = args.train_iters
eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [args.train_iters * global_batch_size,
eval_iters * global_batch_size,
test_iters * global_batch_size]
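# Example with hypothetical settings: train_iters=1000, eval_interval=100,
# eval_iters=10, batch_size=4 and 8 data parallel ranks give
# global_batch_size=32 and targets of
# [1000 * 32 = 32000, (1000 // 100 + 1) * 10 * 32 = 3520, 10 * 32 = 320].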
print_rank_0(' > datasets target sizes (minimum size):')
print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
vocab_file=args.vocab,
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=args.skip_mmap_warmup)
print_rank_0("> finished creating ALBERT datasets ...")
def make_data_loader_(dataset):
if not dataset:
return None
# Use a simple sampler with distributed batch sampler.
sampler = torch.utils.data.SequentialSampler(dataset)
batch_sampler = DistributedBatchSampler(
sampler=sampler,
batch_size=global_batch_size,
drop_last=True,
rank=data_parallel_rank,
world_size=data_parallel_size)
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True)
train_data = make_data_loader_(train_ds)
valid_data = make_data_loader_(valid_ds)
test_data = make_data_loader_(test_ds)
do_train = train_data is not None and args.train_iters > 0
do_valid = valid_data is not None and args.eval_iters > 0
do_test = test_data is not None and args.eval_iters > 0
# Need to broadcast num_tokens and num_type_tokens.
num_tokens = vocab_size_with_padding(train_ds.num_tokens(), args)
token_counts = torch.cuda.LongTensor([num_tokens,
2, # hard coded num_type_tokens
int(do_train),
int(do_valid),
int(do_test)])
else:
token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(token_counts,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
num_tokens = token_counts[0].item()
num_type_tokens = token_counts[1].item()
args.vocab_size = token_counts[0].item()
args.tokentype_size = token_counts[1].item()
args.do_train = token_counts[2].item()
args.do_valid = token_counts[3].item()
args.do_test = token_counts[4].item()
args.vocab_size = num_tokens
args.tokentype_size = num_type_tokens
return train_data, val_data, test_data
return train_data, valid_data, test_data
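The token_counts exchange above is the usual trick for shipping a handful of python scalars from the model parallel source rank to the other ranks: pack them into one CUDA tensor, broadcast, unpack. A generic sketch of that pattern (function name and arguments are mine, not Megatron's):

import torch
import torch.distributed as dist

def broadcast_scalars(values, src_rank, group=None):
    # Every rank passes a list of the same length; only the values on
    # `src_rank` matter, the rest are placeholders (e.g. zeros).
    t = torch.cuda.LongTensor(values)
    dist.broadcast(t, src_rank, group=group)
    return [v.item() for v in t]

# On the source rank:  broadcast_scalars([num_tokens, 2, 1, 1, 1], src)
# On the other ranks:  broadcast_scalars([0, 0, 0, 0, 0], src)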
if __name__ == "__main__":