"examples/vscode:/vscode.git/clone" did not exist on "2b04ec2ff7270d2044410378b04d85a194fa3d4a"
Commit 2d98cfbf authored by Neel Kant

Merge staging-realm into realm-mlm

parents 8b1da95a 4abd7ce2
@@ -42,6 +42,7 @@ def infer_dataset_impl(path):
         else:
             return None
     else:
+        print(f"Dataset path does not exist: {path}")
         return None
@@ -61,6 +62,7 @@ def make_dataset(path, impl, skip_warmup=False):
         return IndexedCachedDataset(path)
     elif impl == 'mmap' and MMapIndexedDataset.exists(path):
         return MMapIndexedDataset(path, skip_warmup)
+    print(f"Unknown dataset implementation: {impl}")
     return None
@@ -466,9 +468,8 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         if isinstance(idx, int):
             ptr, size = self._index[idx]
-            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
-            if self._index.dtype != np.int64:
-                np_array = np_array.astype(np.int64)
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
+                                     count=size, offset=ptr)
             return np_array
         elif isinstance(idx, slice):
             start, stop, step = idx.indices(len(self))
@@ -478,10 +479,25 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
             sizes = self._index._sizes[idx]
             offsets = list(accumulate(sizes))
             total_size = sum(sizes)
-            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr)
+            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
+                                     count=total_size, offset=ptr)
             sents = np.split(np_array, offsets[:-1])
             return sents

+    def get(self, idx, offset=0, length=None):
+        """ Retrieves a single item from the dataset with the option to only
+        return a portion of the item.
+
+        get(idx) is the same as [idx] but get() does not support slicing.
+        """
+        ptr, size = self._index[idx]
+        if length is None:
+            length = size - offset
+        ptr += offset * np.dtype(self._index.dtype).itemsize
+        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
+                                 count=length, offset=ptr)
+        return np_array
+
     @property
     def sizes(self):
         return self._index.sizes
...
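For orientation, a minimal sketch of how the new get() accessor is meant to be used, mirroring the test added later in this commit; the dataset prefix here is a hypothetical placeholder:

    from megatron.data import indexed_dataset

    ds = indexed_dataset.make_dataset('my_corpus_text_sentence', 'mmap')  # hypothetical prefix
    size = ds.sizes[0]                    # token count of item 0
    full = ds.get(0)                      # the whole item, same as ds[0]
    tail = ds.get(0, offset=size - 10)    # last 10 tokens
    head = ds.get(0, length=10)           # first 10 tokens
    mid = ds.get(0, offset=2, length=8)   # tokens 2 through 9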
+# This file isn't really a formal automated test, it's just a place to
+# put some code used during development and manual testing of
+# indexed_dataset.
 import argparse
 import os
 import sys
@@ -7,52 +11,90 @@ import torch
 script_dir = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(os.path.join(script_dir, "../../../"))

-from megatron.data import indexed_dataset, FullBertTokenizer, AlbertDataset
+from megatron.tokenizer import build_tokenizer
+from megatron.data import indexed_dataset
 def test_indexed_dataset(args):
     ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
-    tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
+    tokenizer = build_tokenizer(args)
     print(len(ds.doc_idx))
     print(len(ds))
     print(ds.doc_idx[-1])
     if ds.supports_prefetch:
         # just prefetch the whole thing in test (so assume it is small)
         ds.prefetch(range(len(ds)))
-    for i in range(len(ds.doc_idx)-1):
+    if args.count > len(ds.doc_idx)-1:
+        args.count = len(ds.doc_idx)-1
+    for i in range(args.count):
         start = ds.doc_idx[i]
         end = ds.doc_idx[i+1]
         ids = ds[start:end]
+        print(f"Document {i}:")
+        print("--------------")
         for s in ids:
             assert len(s) > 0
             l = s.data.tolist()
-            tokens = tokenizer.convert_ids_to_tokens(l)
-            for t in tokens:
-                if '\n' in t:
-                    print("Newline in string!")
-                    print(i)
+            text = tokenizer.detokenize(l)
+            print(text)
+            print("---")


-def test_albert_dataset(args):
-    # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
-    # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
-    # ds = AlbertDataset(idataset, tokenizer)
-    ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
-                                  args.epochs, args.max_num_samples,
-                                  args.masked_lm_prob, args.seq_length,
-                                  args.short_seq_prob, args.seed)
-    truncated = 0
-    total = 0
-    for s in ds:
-        ids = s['text']
-        tokens = ds.tokenizer.convert_ids_to_tokens(ids)
-        print(tokens)
-        exit()
+def test_indexed_dataset_get(args):
+    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
+    tokenizer = build_tokenizer(args)
+    size = ds.sizes[0]
+    print(f"size: {size}")
+    full = ds.get(0)
+    print(full)
+    #print(tokenizer.detokenize(full.data.tolist()))
+    print("---")
+    end = ds.get(0, offset=size-10)
+    print(end)
+    #print(tokenizer.detokenize(end.data.tolist()))
+
+    start = ds.get(0, length=10)
+    print(start)
+    #print(tokenizer.detokenize(start.data.tolist()))
+
+    part = ds.get(0, offset=2, length=8)
+    print(part)
+    #print(tokenizer.detokenize(part.data.tolist()))
+
+# def test_albert_dataset(args):
+#     # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
+#     # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
+#     # ds = AlbertDataset(idataset, tokenizer)
+#     ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
+#                                   args.epochs, args.max_num_samples,
+#                                   args.masked_lm_prob, args.seq_length,
+#                                   args.short_seq_prob, args.seed)
+#     truncated = 0
+#     total = 0
+#     for i, s in enumerate(ds):
+#         ids = s['text']
+#         tokens = ds.tokenizer.convert_ids_to_tokens(ids)
+#         print(tokens)
+#         if i >= args.count-1:
+#             exit()
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--data', type=str, help='prefix to data files')
-    parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
     parser.add_argument('--dataset-impl', type=str, default='infer',
                         choices=['lazy', 'cached', 'mmap', 'infer'])
+    parser.add_argument('--count', type=int, default=10,
+                        help='Number of samples/documents to print')
+
+    group = parser.add_argument_group(title='tokenizer')
+    group.add_argument('--tokenizer-type', type=str, required=True,
+                       choices=['BertWordPieceLowerCase',
+                                'GPT2BPETokenizer'],
+                       help='What type of tokenizer to use.')
+    group.add_argument('--vocab-file', type=str, default=None,
+                       help='Path to the vocab file')
+    group.add_argument('--merge-file', type=str, default=None,
+                       help='Path to the BPE merge file (if necessary).')
+
     parser.add_argument('--epochs', type=int, default=5,
                         help='Number of epochs to plan for')
     parser.add_argument('--max-num-samples', type=int, default=None,
@@ -66,12 +108,15 @@ def main():
     parser.add_argument('--seed', type=int, default=1234,
                         help='random seed')
     args = parser.parse_args()
+    args.rank = 0
+    args.make_vocab_size_divisible_by = 128
+    args.model_parallel_size = 1
+
     if args.dataset_impl == "infer":
         args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)

-    test_albert_dataset(args)
-    # test_indexed_dataset(args)
+    # test_albert_dataset(args)
+    test_indexed_dataset_get(args)


 if __name__ == "__main__":
     main()
@@ -918,10 +918,10 @@ class InverseClozeDataset(data.Dataset):
         sample = {
             'input_text': np.array(input_tokens),
-            'input_types': np.array(input_token_types),
+            'query_types': np.array(input_token_types),
             'input_pad_mask': np.array(input_pad_mask),
             'context_text': np.array(context_tokens),
-            'context_types': np.array(context_token_types),
+            'block_types': np.array(context_token_types),
             'context_pad_mask': np.array(context_pad_mask)
         }
...
@@ -65,7 +65,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={}):
     """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
     args = _parse_args(extra_args_provider=extra_args_provider,
                        defaults=args_defaults)
-    _build_tokenizer(args)
+    _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
     _set_timers()
@@ -85,6 +85,13 @@ def _build_tokenizer(args):
     global _GLOBAL_TOKENIZER
     _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
     _GLOBAL_TOKENIZER = build_tokenizer(args)
+    return _GLOBAL_TOKENIZER
+
+
+def rebuild_tokenizer(args):
+    global _GLOBAL_TOKENIZER
+    _GLOBAL_TOKENIZER = None
+    return _build_tokenizer(args)


 def _set_tensorboard_writer(args):
...
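_build_tokenizer now returns the tokenizer, and the new rebuild_tokenizer clears the global first so the initialized-once check does not fire. A hedged usage sketch (module path and args fields assumed from the tokenizer argument group added elsewhere in this commit; file names hypothetical):

    from megatron.global_vars import rebuild_tokenizer  # assumed module path

    args.tokenizer_type = 'GPT2BPETokenizer'
    args.vocab_file = 'gpt2-vocab.json'
    args.merge_file = 'gpt2-merges.txt'
    tokenizer = rebuild_tokenizer(args)  # resets _GLOBAL_TOKENIZER, then rebuilds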
@@ -59,6 +59,7 @@ def _initialize_distributed():
     """Initialize torch.distributed and mpu."""
     args = get_args()

+    device_count = torch.cuda.device_count()
     if torch.distributed.is_initialized():
         if args.rank == 0:
@@ -66,23 +67,25 @@ def _initialize_distributed():
                   'skipping initialization ...', flush=True)
         args.rank = torch.distributed.get_rank()
         args.world_size = torch.distributed.get_world_size()
-        device = torch.cuda.current_device()
-        local_rank = args.rank % torch.cuda.device_count()
-        assert local_rank == device, \
-            'expected local-rank to be the same as rank % device-count.'
+        if device_count > 0:
+            device = torch.cuda.current_device()
+            local_rank = args.rank % device_count
+            assert local_rank == device, \
+                'expected local-rank to be the same as rank % device-count.'
     else:
         if args.rank == 0:
             print('> initializing torch distributed ...', flush=True)
         # Manually set the device ids.
-        device = args.rank % torch.cuda.device_count()
-        if args.local_rank is not None:
-            assert args.local_rank == device, \
-                'expected local-rank to be the same as rank % device-count.'
-        else:
-            args.local_rank = device
-        torch.cuda.set_device(device)
+        if device_count > 0:
+            device = args.rank % device_count
+            if args.local_rank is not None:
+                assert args.local_rank == device, \
+                    'expected local-rank to be the same as rank % device-count.'
+            else:
+                args.local_rank = device
+            torch.cuda.set_device(device)
         # Call the init process
         init_method = 'tcp://'
         master_ip = os.getenv('MASTER_ADDR', 'localhost')
@@ -94,7 +97,8 @@ def _initialize_distributed():
                                          init_method=init_method)

     # Set the model-parallel / data-parallel communicators.
-    mpu.initialize_model_parallel(args.model_parallel_size)
+    if device_count > 0:
+        mpu.initialize_model_parallel(args.model_parallel_size)


 def _init_autoresume():
@@ -112,7 +116,8 @@ def _set_random_seed(seed):
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
-        mpu.model_parallel_cuda_manual_seed(seed)
+        if torch.cuda.device_count() > 0:
+            mpu.model_parallel_cuda_manual_seed(seed)
     else:
         raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
...
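The guards above let initialization proceed on CPU-only machines. The same pattern as a standalone sketch (helper name hypothetical):

    import torch

    def select_device(rank):
        # Fall back to CPU when no CUDA device is visible, instead of
        # crashing in torch.cuda.set_device(); mirrors the guard above.
        device_count = torch.cuda.device_count()
        if device_count > 0:
            device = rank % device_count
            torch.cuda.set_device(device)
            return torch.device('cuda', device)
        return torch.device('cpu')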
@@ -219,6 +219,7 @@ class BertModel(MegatronModule):

 class ICTBertModel(MegatronModule):
+    """Bert-based module for Inverse Cloze task."""

     def __init__(self,
                  ict_head_size,
                  num_tokentypes=2,
@@ -231,41 +232,38 @@ class ICTBertModel(MegatronModule):
             parallel_output=parallel_output
         )

-        self.question_model = BertModel(**bert_args)
-        self._question_key = 'question_model'
-        self.context_model = BertModel(**bert_args)
-        self._context_key = 'context_model'
+        # this model embeds (pseudo-)queries - Embed_input in the paper
+        self.query_model = BertModel(**bert_args)
+        self._query_key = 'question_model'
+
+        # this model embeds evidence blocks - Embed_doc in the paper
+        self.block_model = BertModel(**bert_args)
+        self._block_key = 'context_model'

-    def forward(self, input_tokens, input_attention_mask, input_types,
-                context_tokens, context_attention_mask, context_types, return_logits=False):
+    def forward(self, query_tokens, query_attention_mask, query_types,
+                block_tokens, block_attention_mask, block_types):
+        """Run a forward pass for each of the models and compute the similarity scores."""

-        question_ict_logits, _ = self.question_model.forward(input_tokens, 1 - input_attention_mask, input_types)
-        context_ict_logits, _ = self.context_model.forward(context_tokens, 1 - context_attention_mask, context_types)
-
-        # [batch x h] * [h x batch]
-        retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
-        if return_logits:
-            return question_ict_logits, context_ict_logits, retrieval_scores
-        return retrieval_scores
+        query_logits, _ = self.query_model.forward(query_tokens, 1 - query_attention_mask, query_types)
+        block_logits, _ = self.block_model.forward(block_tokens, 1 - block_attention_mask, block_types)
+        return query_logits, block_logits

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
+        """Save dict with state dicts of each of the models."""
         state_dict_ = {}
-        state_dict_[self._question_key] \
-            = self.question_model.state_dict_for_save_checkpoint(
+        state_dict_[self._query_key] \
+            = self.query_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._context_key] \
-            = self.context_model.state_dict_for_save_checkpoint(
+        state_dict_[self._block_key] \
+            = self.block_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
-        """Customized load."""
-        self.question_model.load_state_dict(
-            state_dict[self._question_key], strict=strict)
-        self.context_model.load_state_dict(
-            state_dict[self._context_key], strict=strict)
+        """Load the state dicts of each of the models"""
+        self.query_model.load_state_dict(
+            state_dict[self._query_key], strict=strict)
+        self.block_model.load_state_dict(
+            state_dict[self._block_key], strict=strict)
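With forward() now returning the two embedding matrices, callers build the score matrix themselves, exactly as forward_step in pretrain_ict.py does later in this commit. A minimal sketch, assuming query_logits and block_logits are [batch x h] tensors:

    import torch
    import torch.nn.functional as F

    # Row i holds the similarity of query i to every block in the batch;
    # the diagonal entries correspond to the true (query, block) pairs.
    retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
    softmaxed = F.softmax(retrieval_scores, dim=1)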
@@ -102,6 +102,7 @@ class ParallelSelfAttention(MegatronModule):
                  output_layer_init_method, layer_number):
         super(ParallelSelfAttention, self).__init__()
         args = get_args()
+        self.fp16 = args.fp16

         self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
@@ -244,7 +245,7 @@ class ParallelSelfAttention(MegatronModule):
                                      query_layer, key_layer)

         # fp32 conversion.
-        if self.attention_softmax_in_fp32:
+        if self.fp16 and self.attention_softmax_in_fp32:
             attention_scores = attention_scores.float()

         # Apply attention mask. [b, np, s, s]
@@ -267,7 +268,7 @@ class ParallelSelfAttention(MegatronModule):
         attention_probs = self._get_attention_probs(attention_scores)

         # fp16 conversion
-        if self.attention_softmax_in_fp32:
+        if self.fp16 and self.attention_softmax_in_fp32:
             attention_probs = attention_probs.half()

         # Context layer. [b, s, hp]
...
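The added self.fp16 check ensures the fp32 round-trip around the softmax only fires when the model actually runs in half precision; in fp32 training the cast back to half would be wrong, and the up-cast a no-op. The pattern as a standalone sketch (function name hypothetical):

    import torch.nn.functional as F

    def attention_softmax(scores, fp16, softmax_in_fp32):
        # Compute the softmax in fp32 for numerical stability, then cast
        # back, but only when the inputs are actually half precision.
        if fp16 and softmax_in_fp32:
            return F.softmax(scores.float(), dim=-1).half()
        return F.softmax(scores, dim=-1)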
@@ -37,11 +37,12 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import make_data_loader
 from megatron.utils import report_memory


-def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
-             extra_args_provider=None, args_defaults={}):
+def pretrain(train_valid_test_dataset_provider, model_provider,
+             forward_step_func, extra_args_provider=None, args_defaults={}):
     """Main training program.

     This function will run the following in the order provided:
@@ -51,9 +52,9 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
         4) train the model using the forward_step_func.

     Arguments:
-        train_val_test_data_provider: a function that builds datasets
-            and returns `train, val, test` dataloaders.
+        train_valid_test_dataset_provider: a function that takes the size of
+            train/valid/test dataset and returns `train, valid, test` datasets.
         model_provider: a function that returns a vanilla version of the
             model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
         forward_step_func: a function that takes a `data iterator` and `model`,
             and returns a `loss` scalar with a dictionary with key:values being
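As a quick illustration of the new contract (a sketch only; the dataset builder and loss computation are hypothetical placeholders, and pretrain_bert.py below shows the real pattern):

    def train_valid_test_datasets_provider(train_val_test_num_samples):
        # Receives [train, valid, test] sample counts and returns datasets.
        return build_my_datasets(train_val_test_num_samples)  # hypothetical builder

    def forward_step(data_iterator, model):
        # Returns a loss scalar plus a dict of reduced values for logging.
        loss = model(next(data_iterator))  # stand-in for the real batch logic
        return loss, {'lm loss': loss}

    pretrain(train_valid_test_datasets_provider, model_provider, forward_step)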
@@ -78,35 +79,28 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     timers('model and optimizer').stop()

     # Data stuff.
-    timers('train/valid/test dataset').start()
-    train_data, val_data, test_data = train_val_test_data_provider()
-    timers('train/valid/test dataset').stop()
-
-    # Train, validation, and test data.
-    timers('train/valid/test dataloader').start()
-    train_data_iterator, val_data_iterator, \
-        test_data_iterator = get_train_val_test_data_iterators(train_data,
-                                                               val_data,
-                                                               test_data)
-    timers('train/valid/test dataloader').stop()
+    timers('train/valid/test data iterators').start()
+    train_data_iterator, valid_data_iterator, test_data_iterator \
+        = build_train_valid_test_data_iterators(
+            train_valid_test_dataset_provider)
+    timers('train/valid/test data iterators').stop()

     # Print setup timing.
     print_rank_0('done with setups ...')
-    timers.log(['model and optimizer', 'train/valid/test dataset',
-                'train/valid/test dataloader'])
+    timers.log(['model and optimizer', 'train/valid/test data iterators'])
     print_rank_0('training ...')

     iteration = 0
     if args.do_train and args.train_iters > 0:
         iteration, _ = train(forward_step_func,
                              model, optimizer, lr_scheduler,
-                             train_data_iterator, val_data_iterator)
+                             train_data_iterator, valid_data_iterator)

     if args.do_valid:
         prefix = 'the end of training for val data'
         evaluate_and_print_results(prefix, forward_step_func,
-                                   val_data_iterator, model,
+                                   valid_data_iterator, model,
                                    iteration, False)

     if args.save and iteration != 0:
@@ -151,8 +145,7 @@ def get_model(model_provider_func):
         return model

     raise NotImplementedError('Unknown DDP implementation specified: {}. '
                               'Exiting.'.format(args.DDP_impl))
-    sys.exit()
 def get_optimizer(model):
@@ -269,19 +262,16 @@ def train_step(forward_step_func, data_iterator,
     timers('forward').start()
     loss, loss_reduced = forward_step_func(data_iterator, model)
     timers('forward').stop()
-    torch.cuda.synchronize()

     # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
     backward_step(optimizer, model, loss)
     timers('backward').stop()
-    torch.cuda.synchronize()

     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
     timers('optimizer').stop()
-    torch.cuda.synchronize()

     # Update learning rate.
     skipped_iter = 0
@@ -354,7 +344,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,

 def train(forward_step_func, model, optimizer, lr_scheduler,
-          train_data_iterator, val_data_iterator):
+          train_data_iterator, valid_data_iterator):
     """Train the model function."""
     args = get_args()
     timers = get_timers()
@@ -381,9 +371,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         iteration += 1

         # Logging.
+        loss_scale = None
+        if args.fp16:
+            loss_scale = optimizer.loss_scale
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
-                                          iteration, optimizer.loss_scale,
+                                          iteration, loss_scale,
                                           report_memory_flag)

         # Autoresume
@@ -402,7 +395,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
            args.do_valid:
             prefix = 'iteration {}'.format(iteration)
             evaluate_and_print_results(prefix, forward_step_func,
-                                       val_data_iterator, model,
+                                       valid_data_iterator, model,
                                        iteration, False)

         if args.exit_interval and iteration % args.exit_interval == 0:
@@ -471,37 +464,87 @@ def evaluate_and_print_results(prefix, forward_step_func,
     print_rank_0('-' * length)
-def get_train_val_test_data_iterators(train_data, val_data, test_data):
-    """Build train/validation/test iterators"""
+def build_train_valid_test_data_iterators(
+        build_train_valid_test_datasets_provider):
+    """Build train, validation, and test data iterators."""
     args = get_args()

+    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
+
+    print_rank_0('> building train, validation, and test datasets ...')
+    # Data loader only on rank 0 of each model parallel group.
+    if mpu.get_model_parallel_rank() == 0:
+        # Rank, size, and global batch size.
+        data_parallel_size = mpu.get_data_parallel_world_size()
+        global_batch_size = args.batch_size * data_parallel_size
+
+        # Number of train/valid/test samples.
+        train_iters = args.train_iters
+        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [train_iters * global_batch_size,
+                                      eval_iters * global_batch_size,
+                                      test_iters * global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
+        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
+
+        # Build the datasets.
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
+            train_val_test_num_samples)
+
+        # Build dataloaders.
+        train_dataloader = make_data_loader(train_ds)
+        valid_dataloader = make_data_loader(valid_ds)
+        test_dataloader = make_data_loader(test_ds)
+
+        # Flags to know if we need to do training/validation/testing.
+        do_train = train_dataloader is not None and args.train_iters > 0
+        do_valid = valid_dataloader is not None and args.eval_iters > 0
+        do_test = test_dataloader is not None and args.eval_iters > 0
+        # Need to broadcast num_tokens and num_type_tokens.
+        flags = torch.cuda.LongTensor(
+            [int(do_train), int(do_valid), int(do_test)])
+    else:
+        flags = torch.cuda.LongTensor([0, 0, 0])
+
+    # Broadcast num tokens.
+    torch.distributed.broadcast(flags,
+                                mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    args.do_train = flags[0].item()
+    args.do_valid = flags[1].item()
+    args.do_test = flags[2].item()
+
     # Shift the start iterations.
-    if train_data is not None:
-        train_data.batch_sampler.start_iter = args.iteration % \
-            len(train_data)
+    if train_dataloader is not None:
+        train_dataloader.batch_sampler.start_iter = args.iteration % \
+            len(train_dataloader)
         print_rank_0('setting training data start iteration to {}'.
-                     format(train_data.batch_sampler.start_iter))
-    if val_data is not None:
+                     format(train_dataloader.batch_sampler.start_iter))
+    if valid_dataloader is not None:
         start_iter_val = (args.iteration // args.eval_interval) * \
             args.eval_iters
-        val_data.batch_sampler.start_iter = 0
+        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
+            len(valid_dataloader)
         print_rank_0('setting validation data start iteration to {}'.
-                     format(val_data.batch_sampler.start_iter))
+                     format(valid_dataloader.batch_sampler.start_iter))

+    # Build iterators.
-    if train_data is not None:
-        train_data_iterator = iter(train_data)
+    if train_dataloader is not None:
+        train_data_iterator = iter(train_dataloader)
     else:
         train_data_iterator = None
-    if val_data is not None:
-        val_data_iterator = iter(val_data)
+    if valid_dataloader is not None:
+        valid_data_iterator = iter(valid_dataloader)
     else:
-        val_data_iterator = None
+        valid_data_iterator = None
-    if test_data is not None:
-        test_data_iterator = iter(test_data)
+    if test_dataloader is not None:
+        test_data_iterator = iter(test_dataloader)
     else:
         test_data_iterator = None

-    return train_data_iterator, val_data_iterator, test_data_iterator
+    return train_data_iterator, valid_data_iterator, test_data_iterator
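To make the sizing arithmetic above concrete, a worked example under assumed settings (batch_size=8 per data-parallel rank, 4-way data parallelism, train_iters=1000, eval_interval=100, eval_iters=10):

    global_batch_size = 8 * 4                       # 32 samples per iteration
    train_samples = 1000 * global_batch_size        # 32000
    eval_iters = (1000 // 100 + 1) * 10             # 110 validation iterations overall
    valid_samples = eval_iters * global_batch_size  # 3520
    test_samples = 10 * global_batch_size           # 320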
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
import time
import os
import sys
from tokenizer import Tokenizer
def tokenize_corpus(filename, np_filename, print_interval=10000):

    print(' > tokenizing {}'.format(filename))
    tokenizer = Tokenizer(cache_dir='./cache')

    tokenized_docs = []
    num_docs = 0
    num_tokens = 0
    start_time = time.time()
    with open(filename, 'r') as f:
        for line in f:
            try:
                myjson = json.loads(line)
                url = myjson['url']
                sample = myjson['text']
                tokens = tokenizer.tokenize_document(sample)
                tokenized_docs.append(np.array(tokens, dtype=np.uint16))
                num_docs += 1
                num_tokens += len(tokens)
                if num_docs % print_interval == 0:
                    print('    processed {:9d} documents in {:.2f} (s) so far'.
                          format(num_docs, time.time() - start_time),
                          flush=True)
            except Exception as e:
                print('    skipping ', line, e)

    print(' >> processed {} documents with a total of {} tokens ...'.format(
        num_docs, num_tokens))
    tokenized_docs = np.array(tokenized_docs, dtype=object)
    np.save(np_filename, tokenized_docs, allow_pickle=True)
    print(' >> saved the tokenized documents to {} ...'.format(np_filename))


if __name__ == '__main__':

    print('building gpt2 dataset ...')

    path = sys.argv[1]
    shard = sys.argv[2]
    input_filename = os.path.join(path,
                                  'shards/shard_{:04d}'.format(int(shard)))
    output_filename = os.path.join(path,
                                   'npys/shard_{:04d}.npy'.format(int(shard)))
    print('will be reading {}'.format(input_filename))
    print('and will write the results to {}'.format(output_filename))

    tokenize_corpus(input_filename, output_filename)
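tokenize_corpus expects loose JSON, one document per line, each with at least 'url' and 'text' fields; unparseable lines are skipped. A tiny sketch of writing a conforming shard (paths and contents hypothetical):

    import json

    with open('shards/shard_0000', 'w') as f:
        for url, text in [('http://example.com/a', 'First document body.'),
                          ('http://example.com/b', 'Second document body.')]:
            # One standalone JSON object per line, as the reader expects.
            f.write(json.dumps({'url': url, 'text': text}) + '\n')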
import glob
import json
import os
import time
import sys
import numpy as np
if __name__ == '__main__':

    print('building the shard sizes ...')

    path = sys.argv[1]
    print('> reading numpy files from {}'.format(path))
    npy_files = glob.glob(path + '/*.npy')
    npy_files.sort()
    print('  found {} numpy files'.format(len(npy_files)))

    size_dict = {}
    counter = 0
    start_time = time.time()
    for filename in npy_files:
        data = np.load(filename, allow_pickle=True)
        size = np.hstack(data).size
        np_filename = os.path.basename(filename)
        size_dict[np_filename] = size
        counter += 1
        if counter % 10 == 0:
            print('  processed {} files in {:.2f} seconds'.format(
                counter, time.time() - start_time))

    output_filename = os.path.join(path, 'sizes.txt')
    with open(output_filename, 'w') as f:
        json.dump(size_dict, f)
    print('> wrote sizes to {}'.format(output_filename))
#!/bin/bash
echo "processing gpt2 data ..."
DIR="/raid/mpatwary/redownload_v0/0-21"
for thread in {0..3}; do
echo " launching thread "$thread && python make_gpt2_dataset.py $DIR $thread > $DIR/logs/shard_$thread.log 2>&1 &
done
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append('..')
from megatron.data_utils.tokenization_gpt2 import GPT2Tokenizer
class Tokenizer:

    def __init__(self, cache_dir=None):
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
                                                       cache_dir=cache_dir)
        self.tokenizer.max_len = int(1e12)
        self.eod_token = self.tokenizer.encoder['<|endoftext|>']
        assert self.eod_token < 65535, 'vocab size will not fit in uint16'
        print('> GPT2 tokenizer with {} vocab size and eod token {} ...'.format(
            len(self.tokenizer.encoder), self.eod_token))

    def tokenize_document(self, document):
        tokens = self.tokenizer.encode(document)
        tokens.append(self.eod_token)
        return tokens
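A short usage sketch of the wrapper: it appends the end-of-document id after every document, and the assert above guarantees the ids fit in uint16 for compact storage:

    import numpy as np
    from tokenizer import Tokenizer

    t = Tokenizer(cache_dir='./cache')
    ids = t.tokenize_document('Hello world.')
    assert ids[-1] == t.eod_token           # every document ends with <|endoftext|>
    arr = np.array(ids, dtype=np.uint16)    # safe: vocab size asserted < 65535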
@@ -25,13 +25,11 @@ from megatron import print_rank_0
 from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.model import BertModel
 from megatron.training import pretrain
-from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses


 def model_provider():
     """Build the model."""
-    args = get_args()

     print_rank_0('building BERT model ...')
@@ -44,6 +42,7 @@ def model_provider():

 def get_batch(data_iterator):
+    """Build the batch."""

     # Items and their type.
     keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
@@ -96,70 +95,28 @@ def forward_step(data_iterator, model):
     return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}


-def get_train_val_test_data():
-    """Load the data on rank zero and broadcast number of tokens to all GPUS."""
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build train, valid, and test datasets."""
     args = get_args()

-    (train_data, valid_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        print_rank_0('> building train, validation, and test datasets '
-                     'for BERT ...')
-
-        data_parallel_size = mpu.get_data_parallel_world_size()
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        global_batch_size = args.batch_size * data_parallel_size
-
-        # Number of train/valid/test samples.
-        train_iters = args.train_iters
-        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
-        test_iters = args.eval_iters
-        train_val_test_num_samples = [train_iters * global_batch_size,
-                                      eval_iters * global_batch_size,
-                                      test_iters * global_batch_size]
-        print_rank_0(' > datasets target sizes (minimum size):')
-        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
-        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
-        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
-
-        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            data_prefix=args.data_path,
-            data_impl=args.data_impl,
-            splits_string=args.split,
-            train_valid_test_num_samples=train_val_test_num_samples,
-            max_seq_length=args.seq_length,
-            masked_lm_prob=args.mask_prob,
-            short_seq_prob=args.short_seq_prob,
-            seed=args.seed,
-            skip_warmup=(not args.mmap_warmup))
-        print_rank_0("> finished creating BERT datasets ...")
-
-        train_data = make_data_loader(train_ds)
-        valid_data = make_data_loader(valid_ds)
-        test_data = make_data_loader(test_ds)
-
-        do_train = train_data is not None and args.train_iters > 0
-        do_valid = valid_data is not None and args.eval_iters > 0
-        do_test = test_data is not None and args.eval_iters > 0
-        # Need to broadcast num_tokens and num_type_tokens.
-        flags = torch.cuda.LongTensor(
-            [int(do_train), int(do_valid), int(do_test)])
-    else:
-        flags = torch.cuda.LongTensor([0, 0, 0])
-
-    # Broadcast num tokens.
-    torch.distributed.broadcast(flags,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    args.do_train = flags[0].item()
-    args.do_valid = flags[1].item()
-    args.do_test = flags[2].item()
+    print_rank_0('> building train, validation, and test datasets '
+                 'for BERT ...')
+    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+        data_prefix=args.data_path,
+        data_impl=args.data_impl,
+        splits_string=args.split,
+        train_valid_test_num_samples=train_val_test_num_samples,
+        max_seq_length=args.seq_length,
+        masked_lm_prob=args.mask_prob,
+        short_seq_prob=args.short_seq_prob,
+        seed=args.seed,
+        skip_warmup=(not args.mmap_warmup))
+    print_rank_0("> finished creating BERT datasets ...")

-    return train_data, valid_data, test_data
+    return train_ds, valid_ds, test_ds


 if __name__ == "__main__":
-    pretrain(get_train_val_test_data, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
@@ -25,7 +25,6 @@ from megatron import print_rank_0
 from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.model import ICTBertModel
 from megatron.training import pretrain
-from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses

 num_batches = 0
@@ -46,8 +45,8 @@ def model_provider():

 def get_batch(data_iterator):
     # Items and their type.
-    keys = ['input_text', 'input_types', 'input_pad_mask',
-            'context_text', 'context_types', 'context_pad_mask']
+    keys = ['query_tokens', 'query_types', 'query_pad_mask',
+            'block_tokens', 'block_types', 'block_pad_mask']
     datatype = torch.int64

     # Broadcast data.
@@ -58,16 +57,15 @@ def get_batch(data_iterator):
     data_b = mpu.broadcast_data(keys, data, datatype)

     # Unpack.
-    input_tokens = data_b['input_text'].long()
-    input_types = data_b['input_types'].long()
-    input_pad_mask = data_b['input_pad_mask'].long()
-    context_tokens = data_b['context_text'].long()
-    context_types = data_b['context_types'].long()
-    context_pad_mask = data_b['context_pad_mask'].long()
-    context_indices = data_b['context_indices'].long()
+    query_tokens = data_b['query_tokens'].long()
+    query_types = data_b['query_types'].long()
+    query_pad_mask = data_b['query_pad_mask'].long()
+    block_tokens = data_b['block_tokens'].long()
+    block_types = data_b['block_types'].long()
+    block_pad_mask = data_b['block_pad_mask'].long()

-    return input_tokens, input_types, input_pad_mask,\
-        context_tokens, context_types, context_pad_mask, context_indices
+    return query_tokens, query_types, query_pad_mask,\
+        block_tokens, block_types, block_pad_mask


 def forward_step(data_iterator, model):
@@ -76,15 +74,18 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
-    input_tokens, input_types, input_pad_mask,\
-        context_tokens, context_types, context_pad_mask = get_batch(data_iterator)
+    query_tokens, query_types, query_pad_mask,\
+        block_tokens, block_types, block_pad_mask = get_batch(data_iterator)
     timers('batch generator').stop()

     # Forward model.
-    retrieval_scores = model(input_tokens, input_pad_mask, input_types,
-                             context_tokens, context_pad_mask, context_types).float()
+    query_logits, block_logits = model(query_tokens, query_pad_mask, query_types,
+                                       block_tokens, block_pad_mask, block_types).float()
+
+    # [batch x h] * [h x batch]
+    retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
     softmaxed = F.softmax(retrieval_scores, dim=1)

     top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
     batch_size = softmaxed.shape[0]
@@ -99,71 +100,29 @@ def forward_step(data_iterator, model):
             'top5_acc': reduced_losses[2]}


-def get_train_val_test_data():
-    """Load the data on rank zero and broadcast number of tokens to all GPUS."""
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build train, valid and test datasets."""
     args = get_args()

-    (train_data, valid_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        print_rank_0('> building train, validation, and test datasets '
-                     'for BERT ...')
-
-        data_parallel_size = mpu.get_data_parallel_world_size()
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        global_batch_size = args.batch_size * data_parallel_size
-
-        # Number of train/valid/test samples.
-        train_iters = args.train_iters
-        eval_iters = (train_iters // args.eval_iters + 1) * args.eval_iters
-        test_iters = args.eval_iters
-        train_val_test_num_samples = [train_iters * global_batch_size,
-                                      eval_iters * global_batch_size,
-                                      test_iters * global_batch_size]
-        print_rank_0(' > datasets target sizes (minimum size):')
-        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
-        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
-        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
-
-        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            data_prefix=args.data_path,
-            data_impl=args.data_impl,
-            splits_string=args.split,
-            train_valid_test_num_samples=train_val_test_num_samples,
-            max_seq_length=args.seq_length,
-            masked_lm_prob=args.mask_prob,
-            short_seq_prob=args.short_seq_prob,
-            seed=args.seed,
-            skip_warmup=(not args.mmap_warmup),
-            ict_dataset=True)
-        print_rank_0("> finished creating BERT ICT datasets ...")
-
-        train_data = make_data_loader(train_ds)
-        valid_data = make_data_loader(valid_ds)
-        test_data = make_data_loader(test_ds)
-
-        do_train = train_data is not None and args.train_iters > 0
-        do_valid = valid_data is not None and args.eval_iters > 0
-        do_test = test_data is not None and args.eval_iters > 0
-        # Need to broadcast num_tokens and num_type_tokens.
-        flags = torch.cuda.LongTensor(
-            [int(do_train), int(do_valid), int(do_test)])
-    else:
-        flags = torch.cuda.LongTensor([0, 0, 0])
-
-    # Broadcast num tokens.
-    torch.distributed.broadcast(flags,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    args.do_train = flags[0].item()
-    args.do_valid = flags[1].item()
-    args.do_test = flags[2].item()
+    print_rank_0('> building train, validation, and test datasets '
+                 'for BERT ...')
+    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+        data_prefix=args.data_path,
+        data_impl=args.data_impl,
+        splits_string=args.split,
+        train_valid_test_num_samples=train_val_test_num_samples,
+        max_seq_length=args.seq_length,
+        masked_lm_prob=args.mask_prob,
+        short_seq_prob=args.short_seq_prob,
+        seed=args.seed,
+        skip_warmup=(not args.mmap_warmup),
+        ict_dataset=True)
+    print_rank_0("> finished creating BERT ICT datasets ...")

-    return train_data, valid_data, test_data
+    return train_ds, valid_ds, test_ds


 if __name__ == "__main__":
-    pretrain(get_train_val_test_data, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
@@ -15,8 +15,6 @@
 """Pretrain GPT2"""

-import os
-
 import torch

 from megatron import get_args
@@ -24,17 +22,15 @@ from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron import print_rank_0
-from megatron.data.gpt2_dataset import GPT2Dataset
+from megatron.data.gpt2_dataset import build_train_valid_test_datasets
 from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses


 def model_provider():
     """Build the model."""
-    args = get_args()

     print_rank_0('building GPT2 model ...')
     model = GPT2Model(num_tokentypes=0, parallel_output=True)
@@ -98,71 +94,26 @@ def forward_step(data_iterator, model):
     return loss, {'lm loss': reduced_loss[0]}


-def make_gpt2_dataloaders():
-    """Build gpt2 dataloaders."""
-    args = get_args()
-
-    # Input parameters.
-    input_data_sizes_file = args.input_data_sizes_file
-    seq_length = args.seq_length
-    initial_seed = args.seed
-
-    # Build the datasets.
-    def _build_dataset(name):
-        return GPT2Dataset(os.path.join(args.data_path, name),
-                           args.input_data_sizes_file,
-                           args.seq_length, args.seed)
-    train_ds = _build_dataset('train')
-    valid_ds = _build_dataset('valid')
-    test_ds = _build_dataset('test')
-
-    # Dataloaders
-    train = make_data_loader(train_ds)
-    valid = make_data_loader(valid_ds)
-    test = make_data_loader(test_ds)
-
-    args.do_train = False
-    args.do_valid = False
-    args.do_test = False
-    if train is not None:
-        args.do_train = True
-    if valid is not None:
-        args.do_valid = True
-    if test is not None:
-        args.do_test = True
-
-    return (train, valid, test)
-
-
-def get_train_val_test_data():
-    """Load the data on rank zero and broadcast number of tokens to all GPUS."""
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build train, valid, and test datasets."""
     args = get_args()

-    (train_data, val_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-
-        (train_data, val_data, test_data) = make_gpt2_dataloaders()
-        flags = torch.cuda.LongTensor([int(args.do_train),
-                                       int(args.do_valid),
-                                       int(args.do_test)])
-    else:
-        flags = torch.cuda.LongTensor([0, 0, 0])
-
-    # Broadcast num tokens.
-    torch.distributed.broadcast(flags,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    args.do_train = flags[0].item()
-    args.do_valid = flags[1].item()
-    args.do_test = flags[2].item()
+    print_rank_0('> building train, validation, and test datasets '
+                 'for GPT2 ...')
+    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+        data_prefix=args.data_path,
+        data_impl=args.data_impl,
+        splits_string=args.split,
+        train_valid_test_num_samples=train_val_test_num_samples,
+        seq_length=args.seq_length,
+        seed=args.seed,
+        skip_warmup=(not args.mmap_warmup))
+    print_rank_0("> finished creating GPT2 datasets ...")

-    return train_data, val_data, test_data
+    return train_ds, valid_ds, test_ds


 if __name__ == "__main__":
-    pretrain(get_train_val_test_data, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})