Commit 3f58649b authored by Mohammad

utils partially refactored and tested

parent 86e7d624
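The hunks below re-export print_rank_0 and the global-state accessors (get_args, get_adlr_autoresume, get_tensorboard_writer, ...) from the megatron package namespace and move the checkpoint helpers to megatron.checkpointing; call sites stop passing args explicitly and fetch it via get_args(). A minimal sketch of the resulting pattern, assuming the imports shown in the diff; the wrapper function and its messages are illustrative only and are not part of this commit:

# Sketch only; resume_or_start() is a hypothetical call site.
from megatron import get_args, print_rank_0
from megatron.checkpointing import load_checkpoint, save_checkpoint

def resume_or_start(model, optimizer, lr_scheduler):
    args = get_args()  # globals are fetched on demand instead of threaded through
    iteration = 0
    if args.load is not None:
        # New signature: no trailing `args` argument.
        iteration = load_checkpoint(model, optimizer, lr_scheduler)
    print_rank_0('starting from iteration {}'.format(iteration))
    if args.save:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)
    return iteration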
@@ -36,7 +36,7 @@ from megatron.utils import Timers
 from megatron.utils import load_checkpoint
 from megatron.utils import report_memory
 from megatron.utils import print_params_min_max_norm
-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from megatron.data_utils import make_tokenizer
@@ -37,7 +37,7 @@ from megatron import mpu
 from megatron.fp16 import FP16_Module
 from megatron.model import GPT2Model
 from megatron.model import DistributedDataParallel as DDP
-from megatron.utils import print_rank_0
+from megatron import print_rank_0

 def get_model(args):
     """Build the model."""
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .global_vars import get_args
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers


def print_rank_0(message):
    """If distributed is initialized print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)
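A minimal usage sketch of the re-exported helper (assuming torch.distributed has already been initialized by the training setup; without distributed initialization it falls back to a plain print; the message string is illustrative):

from megatron import print_rank_0

# Printed once per job when torch.distributed is initialized (rank 0 only),
# and unconditionally in a single, non-distributed process.
print_rank_0('> building the model ...')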
@@ -24,8 +24,8 @@ import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import mpu
-from .global_vars import get_args
-from .utils import print_rank_0
+from megatron import get_args
+from megatron import print_rank_0

 def check_checkpoint_args(checkpoint_args):
@@ -27,7 +27,7 @@ from megatron.data import helpers
 from megatron.tokenizer.bert_tokenization import FullTokenizer as FullBertTokenizer
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
-from megatron.utils import print_rank_0
+from megatron import print_rank_0

 def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
@@ -18,7 +18,7 @@ from itertools import accumulate
 import numpy as np
 import torch

-from megatron.utils import print_rank_0
+from megatron import print_rank_0

 def __best_fitting_dtype(vocab_size=None):
     if vocab_size is not None and vocab_size < 65500:
@@ -21,11 +21,11 @@ import numpy as np
 import torch

+from megatron import get_adlr_autoresume
+from megatron import get_args
+from megatron import get_tensorboard_writer
 from megatron import mpu
-from .global_vars import get_adlr_autoresume
-from .global_vars import get_args
-from .global_vars import get_tensorboard_writer
-from .global_vars import set_global_variables
+from megatron.global_vars import set_global_variables

 def initialize_megatron(extra_args_provider=None, args_defaults={}):
@@ -18,7 +18,7 @@ import torch
 from torch.optim.lr_scheduler import _LRScheduler
 import math

-from megatron.utils import print_rank_0
+from megatron import print_rank_0

 class AnnealingLR(_LRScheduler):
@@ -25,7 +25,7 @@ from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from megatron.module import MegatronModule
-from megatron.utils import print_rank_0
+from megatron import print_rank_0

 class Classification(MegatronModule):
@@ -25,7 +25,7 @@ from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from megatron.module import MegatronModule
-from megatron.utils import print_rank_0
+from megatron import print_rank_0

 class MultipleChoice(MegatronModule):
@@ -35,10 +35,10 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import load_checkpoint
-from megatron.utils import print_rank_0
+from megatron.checkpointing import load_checkpoint
+from megatron import print_rank_0
 from megatron.utils import report_memory
-from megatron.utils import save_checkpoint
+from megatron.checkpointing import save_checkpoint

 def run(top_level_message, train_val_test_data_provider,
@@ -108,8 +108,7 @@ def run(top_level_message, train_val_test_data_provider,
                                    timers, False)

     if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer,
-                        lr_scheduler, args)
+        save_checkpoint(iteration, model, optimizer, lr_scheduler)

     if args.do_test:
         # Run on test data.
@@ -220,7 +219,7 @@ def setup_model_and_optimizer(model_provider_func, args):
     lr_scheduler = get_learning_rate_scheduler(optimizer, args)

     if args.load is not None:
-        args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
+        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
     else:
         args.iteration = 0
@@ -378,12 +377,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         if (iteration % args.adlr_autoresume_interval == 0) and \
            args.adlr_autoresume:
             check_adlr_autoresume_termination(iteration, model, optimizer,
-                                              lr_scheduler, args)
+                                              lr_scheduler)

         # Checkpointing
         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)

         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \
@@ -13,31 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Utilities for logging and serialization"""
+"""General utilities."""

-import os
-import random
-import time
-import numpy as np
-import torch
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-
-from megatron.global_vars import get_args
-from megatron.global_vars import get_adlr_autoresume
-from megatron import mpu
+import sys
+
+import torch
+
+from megatron import get_args
+from megatron import get_adlr_autoresume
+from megatron import print_rank_0
+from megatron.checkpointing import save_checkpoint
 from megatron.fp16 import FP16_Optimizer

-
-def print_rank_0(message):
-    """If distributed is initialized print only on rank 0."""
-    if torch.distributed.is_initialized():
-        if torch.distributed.get_rank() == 0:
-            print(message, flush=True)
-    else:
-        print(message, flush=True)
-

 def reduce_losses(losses):
     """Reduce a tensor of losses across all GPUs."""
     reduced_losses = torch.cat(
@@ -81,20 +69,27 @@ def print_params_min_max_norm(optimizer, iteration):
         string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
     print(string, flush=True)


-def check_adlr_autoresume_termination(iteration, model, optimizer,
-                                      lr_scheduler, args):
+#######################################
+def check_adlr_autoresume_termination(iteration, model,
+                                      optimizer, lr_scheduler):
+    """Check for autoresume signal and exit if it is received."""
+    args = get_args()
+    autoresume = get_adlr_autoresume()
     # Add barrier to ensure consistnecy.
     torch.distributed.barrier()
-    if args.AutoResume.termination_requested():
+    if autoresume.termination_requested():
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)
         print_rank_0(">>> autoresume termination request found!")
         if torch.distributed.get_rank() == 0:
-            args.AutoResume.request_resume()
+            autoresume.request_resume()
         print_rank_0(">>> training terminated. Returning")
-        exit(0)
+        sys.exit(0)
+
+
+###################################################
+from megatron import mpu


 def get_ltor_masks_and_position_ids(data,
@@ -166,239 +161,3 @@ def vocab_size_with_padding(num_tokens, args):
         num_tokens, after - num_tokens, after))
     return after
-
-
-def get_checkpoint_name(checkpoints_path, iteration, release=False,
-                        mp_rank=None):
-    if release:
-        d = 'release'
-    else:
-        d = 'iter_{:07d}'.format(iteration)
-    return os.path.join(checkpoints_path, d,
-                        'mp_rank_{:02d}'.format(
-                            mpu.get_model_parallel_rank() if mp_rank is None \
-                            else mp_rank),
-                        'model_optim_rng.pt')
-
-
-def ensure_directory_exists(filename):
-    dirname = os.path.dirname(filename)
-    if not os.path.exists(dirname):
-        os.makedirs(dirname)
-
-
-def get_checkpoint_tracker_filename(checkpoints_path):
-    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
-
-
-def save_checkpoint(iteration, model, optimizer,
-                    lr_scheduler, args):
-    """Save a model checkpoint."""
-
-    # Only rank zer0 of the data parallel writes to the disk.
-    if isinstance(model, torchDDP):
-        model = model.module
-    if mpu.get_data_parallel_rank() == 0:
-        checkpoint_name = get_checkpoint_name(args.save, iteration)
-        print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
-              format(torch.distributed.get_rank(), iteration, checkpoint_name))
-
-        sd = {}
-        sd['iteration'] = iteration
-        sd['model'] = model.state_dict_for_save_checkpoint()
-
-        # Optimizer stuff.
-        if not args.no_save_optim:
-            if optimizer is not None:
-                sd['optimizer'] = optimizer.state_dict()
-            if lr_scheduler is not None:
-                sd['lr_scheduler'] = lr_scheduler.state_dict()
-
-        # rng states.
-        if not args.no_save_rng:
-            sd['random_rng_state'] = random.getstate()
-            sd['np_rng_state'] = np.random.get_state()
-            sd['torch_rng_state'] = torch.get_rng_state()
-            sd['cuda_rng_state'] = torch.cuda.get_rng_state()
-            sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
-
-        ensure_directory_exists(checkpoint_name)
-        torch.save(sd, checkpoint_name)
-        print(' successfully saved {}'.format(checkpoint_name))
-
-    # Wait so everyone is done (necessary)
-    torch.distributed.barrier()
-    # And update the latest iteration
-    if torch.distributed.get_rank() == 0:
-        tracker_filename = get_checkpoint_tracker_filename(args.save)
-        with open(tracker_filename, 'w') as f:
-            f.write(str(iteration))
-    # Wait so everyone is done (not necessary)
-    torch.distributed.barrier()
-
-
-def load_checkpoint(model, optimizer, lr_scheduler, args):
-    """Load a model checkpoint."""
-
-    if isinstance(model, torchDDP):
-        model = model.module
-    # Read the tracker file and set the iteration.
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
-    if not os.path.isfile(tracker_filename):
-        print_rank_0('WARNING: could not find the metadata file {} '.format(
-            tracker_filename))
-        print_rank_0(' will not load any checkpoints and will start from '
-                     'random')
-        return 0
-    iteration = 0
-    release = False
-    with open(tracker_filename, 'r') as f:
-        metastring = f.read().strip()
-        try:
-            iteration = int(metastring)
-        except ValueError:
-            release = metastring == 'release'
-            if not release:
-                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
-                    tracker_filename))
-                exit()
-    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
-        tracker_filename)
-
-    # Checkpoint.
-    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
-    if mpu.get_data_parallel_rank() == 0:
-        print('global rank {} is loading checkpoint {}'.format(
-            torch.distributed.get_rank(), checkpoint_name))
-
-    # Load the checkpoint.
-    try:
-        sd = torch.load(checkpoint_name, map_location='cpu')
-    except ModuleNotFoundError:
-        # For backward compatibility.
-        print_rank_0(' > deserializing using the old code structure ...')
-        import sys
-        sys.modules['fp16.loss_scaler'] = sys.modules[
-            'megatron.fp16.loss_scaler']
-        sd = torch.load(checkpoint_name, map_location='cpu')
-        sys.modules.pop('fp16.loss_scaler', None)
-    except:
-        print_rank_0('could not load the checkpoint')
-        exit()
-
-    # Iterations.
-    if args.finetune or release:
-        iteration = 0
-    else:
-        try:
-            iteration = sd['iteration']
-        except KeyError:
-            try: # Backward compatible with older checkpoints
-                iteration = sd['total_iters']
-            except KeyError:
-                print_rank_0('A metadata file exists but Unable to load iteration '
-                             ' from checkpoint {}, exiting'.format(checkpoint_name))
-                exit()
-
-    # Model.
-    try:
-        model.load_state_dict(sd['model'])
-    except KeyError:
-        print_rank_0('A metadata file exists but unable to load model '
-                     'from checkpoint {}, exiting'.format(checkpoint_name))
-        exit()
-
-    # Optimizer.
-    if not release and not args.finetune and not args.no_load_optim:
-        try:
-            if optimizer is not None:
-                optimizer.load_state_dict(sd['optimizer'])
-            if lr_scheduler is not None:
-                lr_scheduler.load_state_dict(sd['lr_scheduler'])
-        except KeyError:
-            print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
-                         'Specify --no-load-optim or --finetune to prevent '
-                         'attempting to load the optimizer '
-                         'state.'.format(checkpoint_name))
-            exit()
-
-    # rng states.
-    if not release and not args.finetune and not args.no_load_rng:
-        try:
-            random.setstate(sd['random_rng_state'])
-            np.random.set_state(sd['np_rng_state'])
-            torch.set_rng_state(sd['torch_rng_state'])
-            torch.cuda.set_rng_state(sd['cuda_rng_state'])
-            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
-        except KeyError:
-            print_rank_0('Unable to load optimizer from checkpoint {}, exiting.'
-                         'Specify --no-load-optim or --finetune to prevent '
-                         'attempting to load the optimizer '
-                         'state.'.format(checkpoint_name))
-            exit()
-
-    torch.distributed.barrier()
-    if mpu.get_data_parallel_rank() == 0:
-        print(' successfully loaded {}'.format(checkpoint_name))
-
-    return iteration
-
-
-def load_weights(src, dst, dst2src=False):
-    """
-    Loads weights from src to dst via in place copy.
-    src is a huggingface gpt2model, while dst is one of our models.
-    dst2src=True loads parameters from our models into huggingface's.
-    ^dst2src is still untested
-    """
-    conv_layer = 'Conv1D' in str(type(src))
-    for n, p in src.named_parameters():
-        if dst2src:
-            data = dst._parameters[n].data
-            load = p.data
-        else:
-            data = p.data
-            load = dst._parameters[n].data
-        if conv_layer and 'weight' in n:
-            data = data.t().contiguous()
-        load.copy_(data)
-        # dst._parameters[n].data.copy_(data)
-
-
-def load_mlp(our, oai, dst2src=False):
-    load_weights(oai.c_fc, our.dense_h_to_4h, dst2src)
-    load_weights(oai.c_proj, our.dense_4h_to_h, dst2src)
-
-
-def load_attention(our, oai, dst2src=False):
-    load_weights(oai.c_attn, our.query_key_value, dst2src)
-    load_weights(oai.c_proj, our.dense, dst2src)
-
-
-def load_transformer_layer(our, oai, dst2src=False):
-    load_weights(oai.ln_1, our.input_layernorm, dst2src)
-    load_weights(oai.ln_2, our.post_attention_layernorm, dst2src)
-    load_mlp(our.mlp, oai.mlp, dst2src)
-    load_attention(our.attention, oai.attn, dst2src)
-
-
-def move_weights(our, oai, dst2src=False):
-    """
-    Loads weights from `oai` to `our` via in place copy.
-    `oai` is a huggingface gpt2model, while `our` is one of our models.
-    dst2src=True loads parameters from our models into huggingface's.
-    ^dst2src=True is still untested
-    """
-    # while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)):
-    #     our=our.module
-    transformer_model = oai.transformer
-    load_weights(transformer_model.ln_f, our.transformer.final_layernorm, dst2src)
-    load_weights(transformer_model.wte, our.word_embeddings, dst2src)
-    load_weights(transformer_model.wpe, our.position_embeddings, dst2src)
-    for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h):
-        load_transformer_layer(our_layer, oai_layer, dst2src)
-
-
-def merge_parallel_state_dicts(state_dicts):
-    temp_sd = {}
-    for sd in state_dicts:
-        for k, v in sd.items():
-            temp_sd[k].append()
-    pass
-
-
-def merge_parallel_checkpoints(checkpoint_dir, model_parallel_size):
-    pass
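For reference, the path scheme used by the checkpoint helpers removed above (and, per the new imports, continued in megatron.checkpointing) can be reproduced as follows; the save directory, iteration number, and helper name are illustrative:

import os

def example_checkpoint_name(save_dir='checkpoints', iteration=5000, mp_rank=0):
    # Mirrors get_checkpoint_name(): iter_{:07d}/mp_rank_{:02d}/model_optim_rng.pt
    return os.path.join(save_dir,
                        'iter_{:07d}'.format(iteration),
                        'mp_rank_{:02d}'.format(mp_rank),
                        'model_optim_rng.pt')

print(example_checkpoint_name())
# checkpoints/iter_0005000/mp_rank_00/model_optim_rng.pt
# The tracker file <save_dir>/latest_checkpointed_iteration.txt records the
# latest saved iteration (or the string 'release').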
@@ -20,7 +20,7 @@ import torch.nn.functional as F

 from megatron import mpu
 from megatron.model import BertModel
-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
@@ -22,7 +22,7 @@ from gpt2_data_loader import make_gpt2_dataloaders
 from megatron import mpu
 from megatron.model import GPT2Model
 from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
@@ -21,7 +21,7 @@ import time
 import torch

 from megatron import mpu
-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from .finetune_utils import build_data_loader
 from .finetune_utils import process_batch
@@ -26,7 +26,7 @@ from megatron.training import train_step
 from megatron.training import training_log
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import load_checkpoint
-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import save_checkpoint
@@ -20,7 +20,7 @@ from abc import abstractmethod
 from torch.utils.data import Dataset

-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from tasks.data_utils import build_sample
 from tasks.data_utils import build_tokens_types_paddings_from_text
@@ -15,7 +15,7 @@
 """GLUE finetuning/evaluation."""

-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
@@ -15,7 +15,7 @@
 """MNLI dataset."""

-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from tasks.data_utils import clean_text
 from .data import GLUEAbstractDataset
@@ -15,7 +15,7 @@
 """QQP dataset."""

-from megatron.utils import print_rank_0
+from megatron import print_rank_0
 from tasks.data_utils import clean_text
 from .data import GLUEAbstractDataset