Commit 1a1cfcff authored by Jared Casper

Plumbing for new albert dataset, including mods to arguments for data loaders.

parent 0058b1a5
@@ -267,23 +267,52 @@ def add_data_args(parser):
     group.add_argument('--shuffle', action='store_true',
                        help='Shuffle data. Shuffling is deterministic '
                        'based on seed and current epoch.')
+    group.add_argument('--data-loader', type=str, default=None,
+                       choices=['raw', 'lazy', 'tfrecords', 'numpy', 'binary'],
+                       help='Which data loader to use. Default varies by model.')
     group.add_argument('--train-data', nargs='+', default=None,
-                       help='Whitespace separated filenames or corpora names '
+                       help='Whitespace separated paths or corpora names '
                        'for training.')
+    group.add_argument('--valid-data', nargs='*', default=None,
+                       help='path(s) to the validation data.')
+    group.add_argument('--test-data', nargs='*', default=None,
+                       help='path(s) to the testing data.')
+    group.add_argument('--data-path', type=str, default=None,
+                       help='path to combined dataset to split')
+    group.add_argument('--split', default='1000,1,1',
+                       help='comma-separated list of proportions for training,'
+                       ' validation, and test split')
-    group.add_argument('--use-npy-data-loader', action='store_true',
-                       help='Use the numpy data loader. If set, then'
-                       'train-data-path, val-data-path, and test-data-path'
-                       'should also be provided.')
-    group.add_argument('--train-data-path', type=str, default='',
-                       help='path to the training data')
-    group.add_argument('--val-data-path', type=str, default='',
-                       help='path to the validation data')
-    group.add_argument('--test-data-path', type=str, default='',
-                       help='path to the test data')
+    group.add_argument('--seq-length', type=int, default=512,
+                       help="Maximum sequence length to process")
+    group.add_argument('--max-preds-per-seq', type=int, default=None,
+                       help='Maximum number of predictions to use per sequence.'
+                       'Defaults to math.ceil(`--seq-length`*.15/10)*10.'
+                       'MUST BE SPECIFIED IF `--data-loader tfrecords`.')
+    # arguments for binary data loader
+    parser.add_argument('--vocab', type=str, default='vocab.txt',
+                        help='path to vocab file')
+    parser.add_argument('--data-impl', type=str, default='infer',
+                        help='implementation of indexed datasets',
+                        choices=['lazy', 'cached', 'mmap', 'infer'])
+    parser.add_argument('--max-num-samples', type=int, default=None,
+                        help='Maximum number of samples to plan for, defaults to total iters * batch-size.')
+    parser.add_argument('--data-epochs', type=int, default=None,
+                        help='Number of epochs to plan for, defaults to using --max-num-samples')
+    parser.add_argument('--mask-prob', default=0.15, type=float,
+                        help='probability of replacing a token with mask')
+    parser.add_argument('--short-seq-prob', default=0.1, type=float,
+                        help='probability of producing a short sequence')
+    parser.add_argument('--skip-mmap-warmup', action='store_true',
+                        help='skip warming up mmap files')
+    # arguments for numpy data loader
     group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
-                       help='the filename containing all the shards sizes')
+                       help='the filename containing all the shards sizes for numpy data loader')
+    # arguments for raw/tfrecords data loader
     group.add_argument('--delim', default=',',
                        help='delimiter used to parse csv data files')
     group.add_argument('--text-key', default='sentence',
@@ -291,16 +320,6 @@ def add_data_args(parser):
     group.add_argument('--eval-text-key', default=None,
                        help='key to use to extract text from '
                        'json/csv evaluation datasets')
-    group.add_argument('--valid-data', nargs='*', default=None,
-                       help="""Filename for validation data.""")
-    group.add_argument('--split', default='1000,1,1',
-                       help='comma-separated list of proportions for training,'
-                       ' validation, and test split')
-    group.add_argument('--test-data', nargs='*', default=None,
-                       help="""Filename for testing""")
-    group.add_argument('--lazy-loader', action='store_true',
-                       help='whether to lazy read the data set')
     group.add_argument('--loose-json', action='store_true',
                        help='Use loose json (one json-formatted string per '
                        'newline), instead of tight json (data file is one '
@@ -308,6 +327,7 @@ def add_data_args(parser):
     group.add_argument('--presplit-sentences', action='store_true',
                        help='Dataset content consists of documents where '
                        'each document consists of newline separated sentences')
     group.add_argument('--num-workers', type=int, default=2,
                        help="""Number of workers to use for dataloading""")
     group.add_argument('--tokenizer-model-type', type=str,
@@ -328,16 +348,6 @@ def add_data_args(parser):
                        help='what type of tokenizer to use')
     group.add_argument("--cache-dir", default=None, type=str,
                        help="Where to store pre-trained BERT downloads")
-    group.add_argument('--use-tfrecords', action='store_true',
-                       help='load `--train-data`, `--valid-data`, '
-                       '`--test-data` from BERT tf records instead of '
-                       'normal data pipeline')
-    group.add_argument('--seq-length', type=int, default=512,
-                       help="Maximum sequence length to process")
-    group.add_argument('--max-preds-per-seq', type=int, default=None,
-                       help='Maximum number of predictions to use per sequence.'
-                       'Defaults to math.ceil(`--seq-length`*.15/10)*10.'
-                       'MUST BE SPECIFIED IF `--use-tfrecords` is True.')
     return parser
@@ -355,7 +365,7 @@ def get_args():
     args = parser.parse_args()
-    if not args.train_data and not args.train_data_path:
+    if not args.train_data and not args.data_path:
         print('WARNING: No training data specified')
     args.cuda = torch.cuda.is_available()
...
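As an aside, a minimal sketch (not part of the commit) of how the defaults described in the help text above work out in practice; train_iters and batch_size match the example scripts further down, while eval_iters is an assumed value:

import math

seq_length = 512
train_iters, eval_iters, batch_size = 10000, 100, 4   # eval_iters is illustrative only

# --max-preds-per-seq defaults to math.ceil(`--seq-length` * .15 / 10) * 10
max_preds_per_seq = math.ceil(seq_length * .15 / 10) * 10          # -> 80

# --max-num-samples defaults to total iterations * batch size (see pretrain_albert.py below)
max_num_samples = (train_iters + 2 * eval_iters) * batch_size      # -> 40800

print(max_preds_per_seq, max_num_samples)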
@@ -116,7 +116,7 @@ def make_tfrecord_loaders(args):
 def make_loaders(args):
     """makes training/val/test"""
-    if args.use_tfrecords:
+    if args.data_loader == 'tfrecords':
         return make_tfrecord_loaders(args)
     world_size = torch.distributed.get_world_size(
         group=mpu.get_data_parallel_group())
@@ -134,7 +134,7 @@ def make_loaders(args):
     data_set_args = {
         'path': args.train_data,
         'seq_length': seq_length,
-        'lazy': args.lazy_loader,
+        'lazy': args.data_loader == 'lazy',
         'delim': args.delim,
         'text_key': args.text_key,
         'label_key': 'label',
...
@@ -56,9 +56,9 @@ def make_gpt2_dataloaders(args):
                                        num_workers=num_workers,
                                        pin_memory=True)
-    train = make_data_loader_(args.train_data_path)
-    valid = make_data_loader_(args.val_data_path)
-    test = make_data_loader_(args.test_data_path)
+    train = make_data_loader_(args.train_data)
+    valid = make_data_loader_(args.val_data)
+    test = make_data_loader_(args.test_data)
     args.do_train = False
     args.do_valid = False
...
 from . import indexed_dataset
 from .bert_tokenization import FullTokenizer as FullBertTokenizer
-from .dataset import AlbertDataset
+from .albert_dataset import AlbertDataset
@@ -29,8 +29,13 @@ class AlbertDataset(Dataset):
         self.indexed_dataset = indexed_dataset
         # Build the samples mapping.
+        if not num_epochs:
+            if not max_num_samples:
+                raise ValueError("Need to specify either max_num_samples or num_epochs")
+            num_epochs = int(max_num_samples / len(indexed_dataset)) + 1
         if not max_num_samples:
             max_num_samples = len(indexed_dataset) * num_epochs
+        print(f"Building the sample map for {num_epochs} epochs or {max_num_samples} samples.")
         self.samples_mapping = helpers.build_mapping(
             indexed_dataset.doc_idx,
             indexed_dataset.sizes,
@@ -52,12 +57,17 @@ class AlbertDataset(Dataset):
     @classmethod
     def from_paths(cls, vocab, data_prefix, data_impl,
                    num_epochs, max_num_samples, masked_lm_prob,
-                   max_seq_length, short_seq_prob, seed):
+                   max_seq_length, short_seq_prob, seed, skip_warmup=False):
         tokenizer = FullBertTokenizer(vocab, do_lower_case=True)
-        idx_ds = indexed_dataset.make_dataset(data_prefix, data_impl)
+        print("> Reading dataset index")
+        idx_ds = indexed_dataset.make_dataset(data_prefix, data_impl, skip_warmup)
+        print("> Finished creating indexed dataset")
         return cls(idx_ds, tokenizer, num_epochs, max_num_samples, masked_lm_prob,
                    max_seq_length, short_seq_prob, seed)

+    def num_tokens(self):
+        return self.tokenizer.vocab_size()

     def __len__(self):
         return self.samples_mapping.shape[0]
...
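The sample-planning logic added to AlbertDataset.__init__ above boils down to the following sketch (the dataset length and sample cap are made-up example values):

def plan_epochs_and_samples(dataset_len, num_epochs=None, max_num_samples=None):
    # Mirrors the defaulting in AlbertDataset.__init__: derive whichever of the
    # two limits was not given from the other one.
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples or num_epochs")
        num_epochs = int(max_num_samples / dataset_len) + 1
    if not max_num_samples:
        max_num_samples = dataset_len * num_epochs
    return num_epochs, max_num_samples

# e.g. 1,000,000 indexed samples capped at 40,800 training samples
print(plan_epochs_and_samples(1_000_000, max_num_samples=40_800))   # -> (1, 40800)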
@@ -357,7 +357,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
     # Padding mask.
-    padding_mask = np.array([1]*num_tokens + [0]*padding_length, dtype=np.int64)
+    padding_mask_np = np.array([1]*num_tokens + [0]*padding_length, dtype=np.int64)
     # Lables and loss mask.
     labels = [-1] * max_seq_length
@@ -369,7 +369,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     labels_np = np.array(labels, dtype=np.int64)
     loss_mask_np = np.array(loss_mask, dtype=np.int64)
-    return tokens_np, tokentypes_np, labels, padding_mask, loss_mask
+    return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
...
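For context, the padding scheme visible in this hunk amounts to the sketch below (toy values; the real pad_and_convert_to_numpy also fills masked positions, labels, and the loss mask, which are omitted here):

import numpy as np

max_seq_length = 8
tokens = [101, 7592, 2088, 102]                    # toy token ids
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
filler = [0] * padding_length

tokens_np = np.array(tokens + filler, dtype=np.int64)
padding_mask_np = np.array([1]*num_tokens + [0]*padding_length, dtype=np.int64)
labels = [-1] * max_seq_length                     # -1 marks positions without an LM label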
@@ -30,8 +30,9 @@ py::array build_mapping_impl(const py::array_t<uint32_t>& docs_,
                              const double short_seq_prob,
                              const int seed) {
-    cout << "> building dataset mapping for " << docs_.shape(0) - 1 <<
-      " documents with " << sizes_.shape(0) << " sentences ..." << endl;
+    cout << "> building dataset mapping for " << docs_.shape(0) - 1
+         << " documents with " << sizes_.shape(0) << " sentences ..."
+         << std::flush << endl;
     // For efficiency, convert probability to ratio.
     const auto short_seq_ratio = static_cast<int>(round(1.0 / short_seq_prob));
@@ -72,8 +73,8 @@ py::array build_mapping_impl(const py::array_t<uint32_t>& docs_,
     // For each epoch:
     for (int epoch=0; epoch < num_epochs; ++epoch) {
         if (map_index >= max_num_samples && !second) {
-            cout << " > reached " << max_num_samples << " samples after " <<
-              epoch << " epochs ..." << endl;
+            cout << " > reached " << max_num_samples << " samples after "
+                 << epoch << " epochs ..." << std::flush << endl;
             break;
         }
         // For each document:
@@ -96,8 +97,8 @@ py::array build_mapping_impl(const py::array_t<uint32_t>& docs_,
                 empty_docs += 1;
             }
             if (num_remain_sent == 1) {
-                cout << "***WARNING*** document " << doc <<
-                  " has one sentence" << endl;
+                // cout << "***WARNING*** document " << doc <<
+                //   " has one sentence" << endl;
                 one_sent_docs += 1;
             }
         }
...
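The probability-to-ratio conversion above is just an integer form of a 1-in-N draw; in Python it would look roughly like this (the modulo-style draw is an assumption about how the ratio is consumed later in the mapping loop):

import random

short_seq_prob = 0.1
short_seq_ratio = int(round(1.0 / short_seq_prob))     # -> 10, as in the C++ code above

# A sequence is treated as "short" roughly once every short_seq_ratio samples,
# which avoids a floating-point comparison inside the hot loop.
is_short = (random.randrange(short_seq_ratio) == 0)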
@@ -51,13 +51,15 @@ def make_builder(out_file, impl, vocab_size=None):
         return IndexedDatasetBuilder(out_file)

-def make_dataset(path, impl):
+def make_dataset(path, impl, skip_warmup=False):
+    if impl == 'infer':
+        impl = infer_dataset_impl(path)
     if impl == 'lazy' and IndexedDataset.exists(path):
         return IndexedDataset(path)
     elif impl == 'cached' and IndexedDataset.exists(path):
         return IndexedCachedDataset(path)
     elif impl == 'mmap' and MMapIndexedDataset.exists(path):
-        return MMapIndexedDataset(path)
+        return MMapIndexedDataset(path, skip_warmup)
     return None
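A hedged usage sketch of the updated make_dataset signature; the import path follows the package layout shown above and the data prefix is hypothetical:

from megatron.data import indexed_dataset

# 'infer' now probes the files on disk to pick lazy/cached/mmap automatically;
# skip_warmup=True skips the sequential read that pre-populates the page cache.
ds = indexed_dataset.make_dataset('data/megatron/my_corpus_mmap', 'infer', skip_warmup=True)
if ds is None:
    raise FileNotFoundError('no indexed dataset found for the given prefix')
print(len(ds.sizes), 'sentences,', len(ds.doc_idx) - 1, 'documents')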
@@ -315,7 +317,7 @@ class IndexedDatasetBuilder(object):
 def _warmup_mmap_file(path):
     with open(path, 'rb') as stream:
-        while stream.read(100 * 1024 * 1024):
+        while stream.read(1 * 1024 * 1024):
             pass
@@ -369,7 +371,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
             return _Writer()

-        def __init__(self, path):
+        def __init__(self, path, skip_warmup=False):
             with open(path, 'rb') as stream:
                 magic_test = stream.read(9)
                 assert self._HDR_MAGIC == magic_test, (
@@ -387,13 +389,18 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
                 self._doc_count = struct.unpack('<Q', stream.read(8))[0]
                 offset = stream.tell()
-            _warmup_mmap_file(path)
+            if not skip_warmup:
+                print("> Warming up index mmap file...")
+                _warmup_mmap_file(path)
             self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
             self._bin_buffer = memoryview(self._bin_buffer_mmap)
+            print("> Reading sizes...")
             self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
+            print("> Reading pointers...")
             self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                            offset=offset + self._sizes.nbytes)
+            print("> Reading document index...")
             self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                           offset=offset + self._sizes.nbytes + self._pointers.nbytes)

         def __del__(self):
@@ -419,14 +426,14 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         def __len__(self):
             return self._len

-    def __init__(self, path):
+    def __init__(self, path, skip_warmup=False):
         super().__init__()
         self._path = None
         self._index = None
         self._bin_buffer = None
-        self._do_init(path)
+        self._do_init(path, skip_warmup)

     def __getstate__(self):
         return self._path
@@ -434,13 +441,18 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     def __setstate__(self, state):
         self._do_init(state)

-    def _do_init(self, path):
+    def _do_init(self, path, skip_warmup):
         self._path = path
-        self._index = self.Index(index_file_path(self._path))
-        _warmup_mmap_file(data_file_path(self._path))
+        self._index = self.Index(index_file_path(self._path), skip_warmup)
+        if not skip_warmup:
+            print("> Warming up data mmap file...")
+            _warmup_mmap_file(data_file_path(self._path))
+        print("> Creating numpy buffer of mmap...")
         self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
+        print("> Creating memory view of numpy buffer...")
         self._bin_buffer = memoryview(self._bin_buffer_mmap)
+        print("> Done")

     def __del__(self):
         self._bin_buffer_mmap._mmap.close()
@@ -522,29 +534,3 @@ class MMapIndexedDatasetBuilder(object):
         with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
             index.write(self._sizes, self._doc_idx)

-class indexed_doc_dataset(torch.utils.data.Dataset):
-    def __init__(self, path):
-        impl = infer_dataset_impl(path)
-        self.ds = make_dataset(path, impl)
-        self._docs = []
-        doc_idxs = []
-        for i, s in enumerate(self._sizes):
-            if s > 0:
-                doc_idxs.append(i)
-            else:
-                self._docs.append(doc_idxs)
-                doc_idxs = []
-
-    def __getitem__(self, i):
-        if not isinstance(i, tuple):
-            raise ValueError("Index into indexed_doc_dataset must be a tuple")
-        idx = self._docs[i[0]][i[1]]
-        return self.ds[idx]
-
-    def __len__(self):
-        """Returns number of documents, not number of sentences"""
-        return len(self._docs)
-
-    def doc_len(self, d):
-        return len(self._docs[d])
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataset to split one large one into multiple smaller datasets"""
import torch
import numpy as np
def should_split(split):
"""
given split proportions checks if should split
Examples:
>>> should_split([10,0,0])
False
>>> should_split([1,.1,.2])
True
"""
return max(split)/sum(split) != 1.
def get_split(args):
"""
Get dataset splits from comma separated string list
"""
splits = []
if args.split.find(',') != -1:
splits = [float(s) for s in args.split.split(',')]
elif args.split.find('/') != -1:
splits = [float(s) for s in args.split.split('/')]
else:
splits = [float(args.split)]
split_total = sum(splits)
if split_total < 1.:
splits.append(1-split_total)
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
if args.valid_data is not None:
splits[1] = 0.
if args.test_data is not None:
splits[2] = 0.
final_sum = sum(splits)
return [s/final_sum for s in splits]
class SplitDataset(torch.utils.data.Dataset):
"""
Dataset wrapper to access a subset of another dataset.
Purpose: useful to index into existing datasets, possibly
large-scale datasets as the subindexing operation is done in an
on-the-fly manner.
Arguments:
ds (Dataset or array-like): List of datasets to be subindexed
split_inds (1D array-like): List of indices part of subset
"""
def __init__(self, ds, split_inds, **kwargs):
self.split_inds = list(split_inds)
self.wrapped_data = ds
def __len__(self):
return len(self.split_inds)
def __getitem__(self, index):
return self.wrapped_data[self.split_inds[index]]
def num_tokens(self):
return self.wrapped_data.num_tokens()
def __iter__(self):
for idx in self.split_inds:
yield self.wrapped_data[idx]
def split_ds(ds, split=[.8,.2,.0], shuffle=True):
"""
Split a dataset into subsets given proportions of how
much to allocate per split. If a split is 0% returns None for that split.
Purpose: Useful for creating train/val/test splits
Arguments:
ds (Dataset or array-like): Data to be split.
split (1D array-like): proportions to split `ds`. `sum(splits) != 0`
shuffle (boolean): Randomly split dataset. Default: True
"""
split_sum = sum(split)
if split_sum == 0:
raise Exception('Split cannot sum to 0.')
split = np.array(split)
split /= split_sum
ds_len = len(ds)
inds = np.arange(ds_len)
if shuffle:
np.random.shuffle(inds)
start_idx = 0
residual_idx = 0
rtn_ds = [None]*len(split)
for i, f in enumerate(split):
if f != 0:
proportion = ds_len*split[i]
residual_idx += proportion % 1
split_ = int(int(proportion) + residual_idx)
split_inds = inds[start_idx:start_idx+max(split_, 1)]
rtn_ds[i] = SplitDataset(ds, split_inds)
start_idx += split_
residual_idx %= 1
return rtn_ds
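As an illustration of the helpers above (the Namespace stands in for parsed arguments, and the toy dataset is just a list):

from argparse import Namespace
from megatron.data import split_dataset

args = Namespace(split='949,50,1', valid_data=None, test_data=None)
split = split_dataset.get_split(args)          # -> [0.949, 0.05, 0.001]
print(split_dataset.should_split(split))       # -> True, so the dataset gets partitioned

toy_ds = list(range(1000))
train, val, test = split_dataset.split_ds(toy_ds, split, shuffle=True)
print(len(train), len(val), len(test))         # -> 949 50 1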
@@ -32,13 +32,37 @@ def should_split(split):
     """
     given split proportions checks if should split
     Examples:
     >>> should_split([10,0,0])
     False
     >>> should_split([1,.1,.2])
     True
     """
     return max(split)/sum(split) != 1.

+def get_split(args):
+    """
+    Get dataset splits from comma separated string list
+    """
+    splits = []
+    if args.split.find(',') != -1:
+        splits = [float(s) for s in args.split.split(',')]
+    elif args.split.find('/') != -1:
+        splits = [float(s) for s in args.split.split('/')]
+    else:
+        splits = [float(args.split)]
+    split_total = sum(splits)
+    if split_total < 1.:
+        splits.append(1-split_total)
+    while len(splits) < 3:
+        splits.append(0.)
+    splits = splits[:3]
+    if args.valid_data is not None:
+        splits[1] = 0.
+    if args.test_data is not None:
+        splits[2] = 0.
+    final_sum = sum(splits)
+    return [s/final_sum for s in splits]
+
 def get_ext(path):
     """gets path extension"""
     return os.path.splitext(path)[1]
@@ -108,7 +132,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
         ds = ConcatDataset(datasets)
     # make tokenizer for dataset
     if tokenizer is None:
         tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type,
                                    pad_token, character_converage, **kwargs)
     ds_type = ''
...
@@ -381,7 +381,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True
     while iteration < args.train_iters:
         loss_dict, skipped_iter = train_step(forward_step_func,
                                              train_data_iterator,
                                              model,
...
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain BERT"""
import torch
import torch.nn.functional as F
from configure_data import configure_data
from megatron import mpu
from megatron.model import BertModel
from megatron.utils import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import run
from megatron.data import AlbertDataset, split_dataset
from megatron.data_utils.samplers import DistributedBatchSampler
def model_provider(args):
"""Build the model."""
print_rank_0('building BERT model ...')
model = BertModel(
num_layers=args.num_layers,
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
attention_dropout_prob=args.attention_dropout,
output_dropout_prob=args.hidden_dropout,
max_sequence_length=args.max_position_embeddings,
checkpoint_activations=args.checkpoint_activations,
checkpoint_num_layers=args.checkpoint_num_layers,
add_binary_head=True,
layernorm_epsilon=args.layernorm_epsilon,
num_tokentypes=args.tokentype_size,
parallel_output=True)
return model
def get_batch(data_iterator, timers):
# Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
datatype = torch.int64
# Broadcast data.
timers('data loader').start()
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens = data_b['text'].long()
types = data_b['types'].long()
sentence_order = data_b['is_random'].long()
loss_mask = data_b['loss_mask'].float()
lm_labels = data_b['labels'].long()
padding_mask = data_b['padding_mask'].byte()
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def forward_step(data_iterator, model, args, timers):
"""Forward step."""
# Get the batch.
timers('batch generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
= get_batch(data_iterator, timers)
timers('batch generator').stop()
# Forward model.
lm_logits, sop_logits = model(tokens, 1-padding_mask, tokentype_ids=types)
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
sentence_order.view(-1).contiguous(),
ignore_index=-1)
lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
lm_labels.contiguous())
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
loss = lm_loss + sop_loss
reduced_losses = reduce_losses([lm_loss, sop_loss])
return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}
def get_train_val_test_data(args):
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
(train_data, val_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
if args.data_loader == None:
args.data_loader = 'binary'
if args.data_loader == 'binary':
if not args.max_num_samples:
args.max_num_samples = (args.train_iters + 2 * args.eval_iters) * args.batch_size
if not args.data_path:
print("Albert currently only supports a unified dataset specified with --data-path")
exit(1)
print("Creating AlbertDataset...")
full_data = AlbertDataset.from_paths(args.vocab, args.data_path,
args.data_impl, args.data_epochs,
args.max_num_samples,
args.mask_prob, args.seq_length,
args.short_seq_prob,
args.seed, args.skip_mmap_warmup)
print("Finished creating AlbertDataset...")
split = split_dataset.get_split(args)
if split_dataset.should_split(split):
train_ds, val_ds, test_ds = split_dataset.split_ds(full_data, split, args.shuffle)
else:
train_ds = full_data
num_tokens = train_ds.num_tokens()
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
global_batch_size = args.batch_size * world_size
num_workers = args.num_workers
def make_data_loader_(dataset):
if not dataset:
return None
# Use a simple sampler with distributed batch sampler.
sampler = torch.utils.data.SequentialSampler(dataset)
batch_sampler = DistributedBatchSampler(
sampler=sampler,
batch_size=global_batch_size,
drop_last=True,
rank=rank,
world_size=world_size)
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
train_data = make_data_loader_(train_ds)
valid_data = make_data_loader_(val_ds)
test_data = make_data_loader_(test_ds)
do_train = train_data is not None and args.train_iters > 0
do_valid = valid_data is not None and args.eval_iters > 0
do_test = test_data is not None and args.eval_iters > 0
# Need to broadcast num_tokens and num_type_tokens.
token_counts = torch.cuda.LongTensor([num_tokens,
2, # hard coded num_type_tokens for now
int(do_train),
int(do_valid),
int(do_test)])
else:
print("Unsupported data loader for BERT.")
exit(1)
else:
token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(token_counts,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
num_tokens = token_counts[0].item()
num_type_tokens = token_counts[1].item()
args.do_train = token_counts[2].item()
args.do_valid = token_counts[3].item()
args.do_test = token_counts[4].item()
args.vocab_size = num_tokens
args.tokentype_size = num_type_tokens
return train_data, val_data, test_data
if __name__ == "__main__":
run('Pretrain BERT model', get_train_val_test_data,
model_provider, forward_step)
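A tiny numeric sketch of how the two losses in forward_step above combine (values and shapes are made up; the real code computes the LM loss with mpu.vocab_parallel_cross_entropy across model-parallel ranks):

import torch
import torch.nn.functional as F

lm_loss_per_token = torch.tensor([[2.0, 0.0, 3.0, 0.0]])   # pretend per-token LM losses
loss_mask = torch.tensor([[1.0, 0.0, 1.0, 0.0]])           # only masked positions count
lm_loss = torch.sum(lm_loss_per_token.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()  # -> 2.5

sop_logits = torch.tensor([[1.5, -0.5]])                   # sentence-order prediction head
sentence_order = torch.tensor([0])
sop_loss = F.cross_entropy(sop_logits.float(), sentence_order, ignore_index=-1)

loss = lm_loss + sop_loss
print(loss.item())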
@@ -112,17 +112,23 @@ def get_train_val_test_data(args):
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
-        data_config = configure_data()
-        ds_type = 'BERT'
-        data_config.set_defaults(data_set_type=ds_type, transpose=False)
-        (train_data, val_data, test_data), tokenizer = data_config.apply(args)
-        num_tokens = vocab_size_with_padding(tokenizer.num_tokens, args)
-        # Need to broadcast num_tokens and num_type_tokens.
-        token_counts = torch.cuda.LongTensor([num_tokens,
-                                              tokenizer.num_type_tokens,
-                                              int(args.do_train),
-                                              int(args.do_valid),
-                                              int(args.do_test)])
+        if (args.data_loader == 'raw'
+            or args.data_loader == 'lazy'
+            or args.data_loader == 'tfrecords'):
+            data_config = configure_data()
+            ds_type = 'BERT'
+            data_config.set_defaults(data_set_type=ds_type, transpose=False)
+            (train_data, val_data, test_data), tokenizer = data_config.apply(args)
+            num_tokens = vocab_size_with_padding(tokenizer.num_tokens, args)
+            # Need to broadcast num_tokens and num_type_tokens.
+            token_counts = torch.cuda.LongTensor([num_tokens,
+                                                  tokenizer.num_type_tokens,
+                                                  int(args.do_train),
+                                                  int(args.do_valid),
+                                                  int(args.do_test)])
+        else:
+            print("Unsupported data loader for BERT.")
+            exit(1)
     else:
         token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
...
@@ -168,10 +168,10 @@ def get_train_val_test_data(args):
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
-        if args.use_npy_data_loader:
+        if args.data_loader == 'numpy':
             (train_data, val_data, test_data), num_tokens, \
                 eod_token = make_gpt2_dataloaders(args)
-        else:
+        elif args.data_loader == 'raw' or args.data_loader == 'tfrecords':
             data_config = configure_data()
             data_config.set_defaults(data_set_type='GPT2', transpose=False)
             (train_data, val_data, test_data), tokenizer = data_config.apply(
...
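Taken together, the --data-loader choices now dispatch roughly as follows (a schematic summary, not code from the commit; the strings just describe which path each value takes):

DATA_LOADER_DISPATCH = {
    'binary':    'pretrain_albert.py -> AlbertDataset.from_paths(vocab, data_path, data_impl, ...)',
    'numpy':     'pretrain_gpt2.py   -> make_gpt2_dataloaders(args)',
    'raw':       'configure_data() pipeline over json/csv corpora',
    'lazy':      'configure_data() pipeline with lazy reading',
    'tfrecords': 'configure_data() pipeline, routed to make_tfrecord_loaders(args)',
}

def describe(data_loader):
    # Default for BERT is now 'binary'; unknown values hit the new
    # "Unsupported data loader" error paths shown above.
    return DATA_LOADER_DISPATCH.get(data_loader, 'unsupported for this model')

print(describe('binary'))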
#!/bin/bash
RANK=0
WORLD_SIZE=1
python pretrain_albert.py \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--batch-size 4 \
--seq-length 512 \
--max-preds-per-seq 80 \
--max-position-embeddings 512 \
--train-iters 10000 \
--save checkpoints/albert_117m \
--load checkpoints/albert_117m \
--resume-dataloader \
--data-path data/megatron/bc_rn_owt_sto_wiki_dedup_shuf_cleaned_0.7_mmap \
--vocab data/megatron/vocab.txt \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--lr-decay-iters 990000 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--fp16 \
--fp32-layernorm \
--fp32-embedding \
--skip-mmap-warmup \
--num-workers 0
#!/bin/bash
GPUS_PER_NODE=2
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_albert.py \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--batch-size 4 \
--seq-length 512 \
--max-preds-per-seq 80 \
--max-position-embeddings 512 \
--train-iters 10000 \
--save checkpoints/albert_117m \
--load checkpoints/albert_117m \
--resume-dataloader \
--data-path data/megatron/bc_rn_owt_sto_wiki_dedup_shuf_cleaned_0.7_mmap \
--vocab data/megatron/vocab.txt \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--lr-decay-iters 990000 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--fp16 \
--fp32-layernorm \
--fp32-embedding \
--skip-mmap-warmup \
--num-workers 0