Commit 6c521eb1 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'staging_args' into 'staging'

Major refactor: args, global variables, tokenizer

See merge request ADLR/megatron-lm!36
parents 37ff534f b07f1d0a
......@@ -36,7 +36,7 @@ from megatron.utils import Timers
from megatron.utils import load_checkpoint
from megatron.utils import report_memory
from megatron.utils import print_params_min_max_norm
from megatron.utils import print_rank_0
from megatron import print_rank_0
from megatron.data_utils import make_tokenizer
......
......@@ -37,7 +37,7 @@ from megatron import mpu
from megatron.fp16 import FP16_Module
from megatron.model import GPT2Model
from megatron.model import DistributedDataParallel as DDP
from megatron.utils import print_rank_0
from megatron import print_rank_0
def get_model(args):
"""Build the model."""
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .global_vars import get_args
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
def print_rank_0(message):
"""If distributed is initialized print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
This diff is collapsed.
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input/output checkpointing."""
import os
import random
import sys
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
def check_checkpoint_args(checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input
arguments and the one retreived frm checkpoint."""
args = get_args()
def _compare(arg_name):
checkpoint_value = getattr(checkpoint_args, arg_name)
args_value = getattr(args, arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the ' \
'input argument value ({}).'.format(
arg_name, checkpoint_value, args_value)
assert checkpoint_value == args_value, error_message
_compare('num_layers')
_compare('hidden_size')
_compare('num_attention_heads')
_compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
_compare('model_parallel_size')
def ensure_directory_exists(filename):
"""Build filename's path if it does not already exists."""
dirname = os.path.dirname(filename)
if not os.path.exists(dirname):
os.makedirs(dirname)
def get_checkpoint_name(checkpoints_path, iteration,
release=False, mp_rank=None):
"""A unified checkpoint name."""
if release:
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_model_parallel_rank() if mp_rank is None \
else mp_rank),
'model_optim_rng.pt')
def get_checkpoint_tracker_filename(checkpoints_path):
"""Tracker file rescords the latest chckpoint during
training to restart from."""
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
"""Save a model checkpoint."""
args = get_args()
# Only rank zero of the data parallel writes to the disk.
if isinstance(model, torchDDP):
model = model.module
if mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
state_dict['iteration'] = iteration
state_dict['model'] = model.state_dict_for_save_checkpoint()
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
state_dict['optimizer'] = optimizer.state_dict()
if lr_scheduler is not None:
state_dict['lr_scheduler'] = lr_scheduler.state_dict()
# RNG states.
if not args.no_save_rng:
state_dict['random_rng_state'] = random.getstate()
state_dict['np_rng_state'] = np.random.get_state()
state_dict['torch_rng_state'] = torch.get_rng_state()
state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
state_dict['rng_tracker_states'] \
= mpu.get_cuda_rng_tracker().get_states()
# Save.
checkpoint_name = get_checkpoint_name(args.save, iteration)
print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
format(torch.distributed.get_rank(), iteration, checkpoint_name))
ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name)
print(' successfully saved {}'.format(checkpoint_name))
# Wait so everyone is done (necessary)
torch.distributed.barrier()
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))
# Wait so everyone is done (not necessary)
torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler):
"""Load a model checkpoint and return the iteration."""
args = get_args()
if isinstance(model, torchDDP):
model = model.module
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(args.load)
# If no tracker file, return iretation zero.
if not os.path.isfile(tracker_filename):
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from '
'random')
return 0
# Otherwise, read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration = 0
release = False
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == 'release'
if not release:
print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
tracker_filename))
sys.exit()
assert iteration > 0 or release, 'error parsing metadata file {}'.format(
tracker_filename)
# Checkpoint.
checkpoint_name = get_checkpoint_name(args.load, iteration, release)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
# Load the checkpoint.
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
except ModuleNotFoundError:
# For backward compatibility.
print_rank_0(' > deserializing using the old code structure ...')
sys.modules['fp16.loss_scaler'] = sys.modules[
'megatron.fp16.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
except:
print_rank_0('could not load the checkpoint')
sys.exit()
# Set iteration.
if args.finetune or release:
iteration = 0
else:
try:
iteration = state_dict['iteration']
except KeyError:
try: # Backward compatible with older checkpoints
iteration = state_dict['total_iters']
except KeyError:
print_rank_0('A metadata file exists but unable to load '
'iteration from checkpoint {}, exiting'.format(
checkpoint_name))
sys.exit()
# Check arguments.
if 'args' in state_dict:
checkpoint_args = state_dict['args']
check_checkpoint_args(checkpoint_args)
else:
print_rank_0('could not find arguments in the checkpoint ...')
# Model.
model.load_state_dict(state_dict['model'])
# Optimizer.
if not release and not args.finetune and not args.no_load_optim:
try:
if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer state, '
'exiting ...'.format(checkpoint_name))
sys.exit()
# rng states.
if not release and not args.finetune and not args.no_load_rng:
try:
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
mpu.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-rng or --finetune to prevent '
'attempting to load the optimizer state, '
'exiting ...'.format(checkpoint_name))
sys.exit()
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
return iteration
from . import indexed_dataset
from .bert_tokenization import FullTokenizer as FullBertTokenizer
......@@ -22,24 +22,19 @@ import numpy as np
import torch
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import mpu
from megatron.data import helpers
from megatron.data import FullBertTokenizer
from megatron.data.dataset_utils import build_training_sample
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.utils import print_rank_0
from megatron import print_rank_0
def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
splits_string, train_valid_test_num_samples,
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup):
# Tokenizer is the same
tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
tokenizer.vocab_size()))
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
......@@ -82,7 +77,6 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
dataset = BertDataset(
name=name,
indexed_dataset=indexed_dataset,
tokenizer=tokenizer,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
......@@ -107,7 +101,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
class BertDataset(Dataset):
def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed):
......@@ -117,8 +111,7 @@ class BertDataset(Dataset):
self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length
# Tokenizer and dataset.
self.tokenizer = tokenizer
# Dataset.
self.indexed_dataset = indexed_dataset
......@@ -133,16 +126,13 @@ class BertDataset(Dataset):
self.name)
# Vocab stuff.
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_dict = self.tokenizer.inv_vocab
self.cls_id = self.tokenizer.vocab['[CLS]']
self.sep_id = self.tokenizer.vocab['[SEP]']
self.mask_id = self.tokenizer.vocab['[MASK]']
self.pad_id = self.tokenizer.vocab['[PAD]']
def num_tokens(self):
return self.tokenizer.vocab_size()
tokenizer = get_tokenizer()
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
self.vocab_id_to_token_dict = tokenizer.inv_vocab
self.cls_id = tokenizer.cls
self.sep_id = tokenizer.sep
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
def __len__(self):
......
......@@ -13,71 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 dataset."""
import json
import os
import numpy as np
import torch
from torch.multiprocessing import Lock
from torch.utils.data import Dataset
from megatron import mpu
from megatron.data_utils.samplers import DistributedBatchSampler
from megatron.data_utils.tokenization_gpt2 import GPT2Tokenizer
def make_gpt2_dataloaders(args):
# Input parameters.
input_data_sizes_file = args.input_data_sizes_file
seq_length = args.seq_length
initial_seed = args.seed
# Data parallel arguments.
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
global_batch_size = args.batch_size * world_size
num_workers = args.num_workers
def make_data_loader_(data_path):
# Build the dataset.
dataset = GPT2Dataset(data_path, input_data_sizes_file,
seq_length, initial_seed)
# Use a simple sampler with distributed batch sampler.
sampler = torch.utils.data.SequentialSampler(dataset)
batch_sampler = DistributedBatchSampler(sampler=sampler,
batch_size=global_batch_size,
drop_last=True,
rank=rank,
world_size=world_size)
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
train = make_data_loader_(args.train_data)
valid = make_data_loader_(args.valid_data)
test = make_data_loader_(args.test_data)
args.do_train = False
args.do_valid = False
args.do_test = False
if train is not None:
args.do_train = True
if valid is not None:
args.do_valid = True
if test is not None:
args.do_test = True
# Tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=args.cache_dir)
eod_token = tokenizer.encoder['<|endoftext|>']
num_tokens = eod_token + 1
return (train, valid, test), num_tokens, eod_token
class GPT2Dataset(Dataset):
......@@ -89,8 +33,6 @@ class GPT2Dataset(Dataset):
self.seq_length = seq_length
self.initial_seed = initial_seed
self.max_epochs = max_epochs
# Lock for building the dataset.
self.lock = Lock()
# Shard stuff.
# Dictionary from shard nameto its size (number of element).
......@@ -120,13 +62,11 @@ class GPT2Dataset(Dataset):
# data index in the shard.
data_idx = idx - self.shards_start_index[shard_index]
# Load the shard if it is not in memory.
#self.lock.acquire()
if self.shards_data[shard_index] is None:
print('global rank {} is building data for shard index {} ...'.
format(torch.distributed.get_rank(), shard_index))
self.build_dataset_(shard_index)
#assert self.shards_data[shard_index] is not None
#self.lock.release()
# Start index.
start_index = self.shards_sample_index[shard_index][data_idx]
# Add one for label shift.
......@@ -194,18 +134,3 @@ class GPT2Dataset(Dataset):
size = self.shard_size_dict[shard]
self.shards_start_index[i] = self.shards_start_index[i-1] + \
size // self.seq_length
'''
if __name__ == '__main__':
print('gpt2 data loader ...')
path = '/raid/mshoeybi/data/gpt2/adlr/reddit_all_ftfy_lg200/npys'
dataset = GPT2Dataset(path, 'sizes.txt', 1024, 1234, 100)
print('dataset contains {} samples'.format(dataset.data_length))
for i in range(len(dataset)):
if i % 512000 == 0:
print(i)
data = dataset[i]
'''
......@@ -18,7 +18,7 @@ from itertools import accumulate
import numpy as np
import torch
from megatron.utils import print_rank_0
from megatron import print_rank_0
def __best_fitting_dtype(vocab_size=None):
if vocab_size is not None and vocab_size < 65500:
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batch samplers that work with either random or sequential data samplers."""
import torch
from torch.utils import data
class RandomSampler(data.sampler.Sampler):
"""Based off of pytorch RandomSampler and DistributedSampler. Essentially
a RandomSampler, but this class lets the user set an epoch like
DistributedSampler Samples elements randomly. If without replacement, then
sample from a shuffled dataset. If with replacement, then user can
specify ``num_samples`` to draw.
Arguments:
data_source (Dataset): dataset to sample from
num_samples (int): number of samples to draw, default=len(dataset)
replacement (bool): samples are drawn with replacement if ``True``,
default=False
"""
def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.epoch = -1
if self._num_samples is not None and replacement is False:
raise ValueError("With replacement=False, num_samples should not "
"be specified, since a random permute will be "
"performed.")
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(
self.num_samples))
if not isinstance(self.replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement))
@property
def num_samples(self):
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self):
n = len(self.data_source)
g = torch.Generator()
if self.epoch >= 0:
g.manual_seed(self.epoch)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,),
dtype=torch.int64, generator=g).tolist())
return iter(torch.randperm(n, generator=g).tolist())
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
class DistributedBatchSampler(data.sampler.BatchSampler):
"""Similar to normal implementation of distributed sampler, except
implementation is at the batch sampler level, instead of just the
sampler level. This allows wrapping of arbitrary data samplers
(sequential, random, WeightedRandomSampler, etc.) with this batch
sampler."""
def __init__(self, sampler, batch_size, drop_last, rank=-1,
world_size=2, wrap_last=False):
super(DistributedBatchSampler, self).__init__(sampler, batch_size,
drop_last)
if rank == -1:
assert False, 'should not be here'
rank = torch.distributed.get_rank()
self.rank = rank
self.world_size = world_size
self.sampler.wrap_around = 0
self.wrap_around = 0
self.wrap_last = wrap_last
self.start_iter = 0
def __iter__(self):
batch = []
i = 0
for idx in self.data_iterator(self.sampler, wrap_around=False):
batch.append(idx)
if len(batch) == self.batch_size:
tbatch = self._batch(batch)
if i >= self.start_iter:
yield tbatch
self.start_iter = 0
i += 1
batch = []
batch_len = len(batch)
if batch_len > 0 and not self.drop_last:
if self.wrap_last:
self.sampler.wrap_around -= (self.batch_size)
self.wrap_around += (len(batch))
self.wrap_around %= self.batch_size
yield self._batch(batch)
if self.wrap_last:
self.sampler.wrap_around += self.batch_size
def data_iterator(self, _iter, wrap_around=False):
"""iterates through data and handles wrap around"""
for i, idx in enumerate(_iter):
if i < self.wrap_around%self.batch_size:
continue
if wrap_around:
self.wrap_around += 1
self.wrap_around %= self.batch_size
yield idx
def _batch(self, batch):
"""extracts samples only pertaining to this worker's batch"""
start = self.rank*self.batch_size//self.world_size
end = (self.rank+1)*self.batch_size//self.world_size
return batch[start:end]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment