Commit 942c402d authored by Jared Casper

Making loading arguments from checkpoint cleaner and available more broadly.

parent 06fc51ce
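In outline, the refactor splits argument handling into explicit steps: parse, optionally override from a checkpoint, validate, install as globals. A minimal sketch of the intended call sequence, pieced together from the hunks below (nothing here beyond what the diff itself adds; the checkpoint directory comes from --load):

```python
from megatron.arguments import parse_args, validate_args
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables

# 1. Parse the command line only; no validation, no globals yet.
args = parse_args()

# 2. Optionally pull model arguments (num_layers, hidden_size, ...) out of
#    the checkpoint directory given by --load.
if args.use_checkpoint_args:
    assert args.load is not None, '--use-checkpoint-args requires --load'
    load_args_from_checkpoint(args)

# 3. Validate the merged arguments, then install them as the global args.
validate_args(args)
set_global_variables(args)
```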
@@ -25,22 +25,6 @@ from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .initialize import initialize_megatron
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)
def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
from .utils import (print_rank_0,
is_last_rank,
print_rank_last)
@@ -20,8 +20,7 @@ import os
import torch
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False, validate=True):
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
allow_abbrev=False)
@@ -53,14 +52,13 @@ def parse_args(extra_args_provider=None, defaults={},
else:
args = parser.parse_args()
if validate:
return validate_args(args, defaults)
# Args from environment
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
return args
def validate_args(args, defaults={}):
# Distributed args.
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
# Tensor model parallel size.
args.tensor_model_parallel_size = min(
args.tensor_model_parallel_size, args.world_size)
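Callers that previously passed defaults= and validate=True to parse_args now make two calls, which is what lets checkpoint arguments be injected in between. A hedged sketch of the new calling convention (the defaults dict is illustrative, modeled on the pretrain scripts):

```python
# Before: args = parse_args(extra_args_provider, defaults={...}, validate=True)
# After:  parsing and validation are separate steps.
args = parse_args(extra_args_provider=None, ignore_unknown_args=False)
args = validate_args(args, defaults={'tokenizer_type': 'GPT2BPETokenizer'})
```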
@@ -628,6 +626,9 @@ def _add_checkpointing_args(parser):
'can reduce startup time when definitely loading from a '
'checkpoint',
dest='perform_initialization')
group.add_argument('--use-checkpoint-args', action='store_true',
help='Override any command line arguments with arguments '
'from the checkpoint')
return parser
@@ -22,11 +22,12 @@ import numpy as np
import torch
from megatron import (get_args,
mpu,
print_rank_0,
update_num_microbatches,
utils)
from megatron import (mpu,
update_num_microbatches)
from .global_vars import get_args
from .utils import (unwrap_model,
print_rank_0)
_CHECKPOINT_VERSION = None
@@ -207,7 +208,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
args = get_args()
# Only rank zero of the data parallel writes to the disk.
model = utils.unwrap_model(model)
model = unwrap_model(model)
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
@@ -386,8 +387,11 @@ def _load_base_checkpoint(load_dir, rank0=False):
return state_dict, release
def load_args_from_checkpoint(args, load_arg='load'):
"""Set any arguments that are not currently set from the checkpoint
specified in the arguments.
"""Set required arguments from the checkpoint specified in the
arguments.
Will overwrite arguments that have a non-None default value, but
will leave any arguments that default to None as set.
Returns the same args NameSpace with the new values added/updated.
@@ -406,6 +410,7 @@ def load_args_from_checkpoint(args, load_arg='load'):
return args
if 'args' not in state_dict:
print('Checkpoint provided does not have arguments saved.')
return args
checkpoint_args = state_dict['args']
@@ -422,7 +427,7 @@ def load_args_from_checkpoint(args, load_arg='load'):
checkpoint_value = getattr(checkpoint_args, arg_name, None)
if checkpoint_value is not None:
print(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
setattr(args, arg_name, checkpoint_value)
_set_arg('num_layers')
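From the visible lines, _set_arg is a small closure over args and checkpoint_args that copies a value only when the checkpoint actually recorded it; roughly (the second call is illustrative, the full list of copied arguments continues in the file):

```python
def _set_arg(arg_name):
    # Copy a saved value from the checkpoint's args onto the live args,
    # leaving args untouched when the checkpoint did not record it.
    checkpoint_value = getattr(checkpoint_args, arg_name, None)
    if checkpoint_value is not None:
        print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
        setattr(args, arg_name, checkpoint_value)

_set_arg('num_layers')
_set_arg('hidden_size')  # illustrative; more architecture arguments follow
```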
@@ -453,7 +458,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
args = get_args()
load_dir = getattr(args, load_arg)
model = utils.unwrap_model(model)
model = unwrap_model(model)
state_dict, release = _load_base_checkpoint(load_dir, False)
@@ -574,7 +579,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
args = get_args()
model = utils.unwrap_model(model)
model = unwrap_model(model)
load_path = custom_load_path if custom_load_path is not None else args.load
@@ -23,7 +23,6 @@ import torch
from megatron import dist_signal_handler
from megatron.tokenizer import build_tokenizer
from .arguments import parse_args
from .microbatches import build_num_microbatches_calculator
_GLOBAL_ARGS = None
@@ -86,16 +85,14 @@ def _set_signal_handler():
_ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
_GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()
def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False, parse_args=True):
def set_global_variables(args):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
if parse_args:
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
else:
_ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
args = get_args()
assert args is not None
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
set_args(args)
_build_num_microbatches_calculator(args)
if args.vocab_file:
_ = _build_tokenizer(args)
@@ -117,10 +114,9 @@ def _parse_args(extra_args_provider=None, defaults={},
"""Parse entire arguments."""
global _GLOBAL_ARGS
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults,
ignore_unknown_args=ignore_unknown_args,
validate=True)
_GLOBAL_ARGS = args
return _GLOBAL_ARGS
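set_global_variables no longer parses anything itself; callers hand it a fully parsed (and normally validated) namespace, and later get_args() lookups return that same object. A small sketch under that assumption:

```python
from megatron.global_vars import set_global_variables, get_args

set_global_variables(args)   # args must already be parsed and validated
assert get_args() is args    # every later lookup sees the same namespace
```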
@@ -28,6 +28,8 @@ from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
set_tensor_model_parallel_world_size)
@@ -47,11 +49,18 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
# Make sure cuda is available.
assert torch.cuda.is_available(), 'Megatron requires CUDA.'
# Parse args, build tokenizer, and set adlr-autoresume,
# Parse arguments
args = parse_args(extra_args_provider, ignore_unknown_args)
if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
assert args.load is not None, '--use-checkpoint-args requires --load argument'
load_args_from_checkpoint(args)
validate_args(args, args_defaults)
# Set global args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables(extra_args_provider=extra_args_provider,
args_defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
set_global_variables(args)
# torch.distributed initialization
def finish_mpu_init():
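Because initialize_megatron also consults args_defaults (note the args_defaults.get('use_checkpoint_args', False) above), an entry point can opt into the checkpoint override without the command-line flag. A hedged example; the defaults shown are illustrative, and --load must still be supplied so there is a checkpoint to read:

```python
from megatron.initialize import initialize_megatron

initialize_megatron(extra_args_provider=None,
                    args_defaults={'use_checkpoint_args': True,
                                   'tokenizer_type': 'GPT2BPETokenizer'})
```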
@@ -72,16 +72,6 @@ class MegatronModule(torch.nn.Module):
if args.pipeline_model_parallel_size == 1:
return
if not torch.distributed.is_initialized():
if not getattr(MegatronModule, "embedding_warning_printed", False):
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
MegatronModule.embedding_warning_printed = True
return
# Parameters are shared between the word embeddings layers, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
@@ -112,6 +102,16 @@ class MegatronModule(torch.nn.Module):
self.pre_process:
self.language_model.embedding.zero_parameters()
if not torch.distributed.is_initialized():
if not getattr(MegatronModule, "embedding_warning_printed", False):
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
MegatronModule.embedding_warning_printed = True
return
# Ensure that first and last stages have the same initial parameter
# values.
if mpu.is_rank_in_embedding_group():
@@ -24,7 +24,6 @@ from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
@@ -204,3 +203,22 @@ def get_ltor_masks_and_position_ids(data,
return attention_mask, loss_mask, position_ids
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)
def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
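With the helpers relocated, both the package-level and the utils-level imports keep working, since the package __init__ re-exports them (first hunk above). For instance:

```python
# Either import path works after this change (assuming torch.distributed
# has been initialized, as in normal Megatron runs).
from megatron import print_rank_0            # re-exported by __init__
from megatron.utils import print_rank_last   # new home

print_rank_0('printed on global rank 0 only')
print_rank_last('printed on the last rank only')
```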
@@ -29,7 +29,7 @@ def _load_checkpoint(queue, args):
from megatron.arguments import parse_args, validate_args
from megatron.global_vars import set_args, set_global_variables
from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
from megatron.model import ModelType
from megatron.model import ModelType, module
from megatron import mpu, fused_kernels
except ModuleNotFoundError:
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
@@ -51,9 +51,15 @@ def _load_checkpoint(queue, args):
'--load', args.load_dir
]
margs = parse_args(validate=False)
margs = parse_args()
margs = load_args_from_checkpoint(margs)
# Argument validation sanity-checks the world size, but we don't care here,
# so trick it into thinking we have plenty of processes
margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size
margs = validate_args(margs)
def check_for_arg(arg_name):
if getattr(margs, arg_name, None) is None:
print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
@@ -71,13 +77,6 @@ def _load_checkpoint(queue, args):
check_for_arg('tokenizer_type')
check_for_arg('iteration')
check_for_arg('bert_binary_head')
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes
os.environ["WORLD_SIZE"] = f'{margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size}'
margs = validate_args(margs)
check_for_arg('params_dtype')
# Determine how to make our models
@@ -90,6 +89,9 @@ def _load_checkpoint(queue, args):
else:
raise Exception(f'unrecognized model type: {args.model_type}')
# suppress warning about torch.distributed not being initialized
module.MegatronModule.embedding_warning_printed = True
def get_models(count, dtype, pre_process, post_process):
# with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
# futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
@@ -105,14 +107,12 @@ def _load_checkpoint(queue, args):
models.append(model_[0])
return models
set_args(margs)
if margs.num_layers_per_virtual_pipeline_stage is not None:
print("Model with an interleaved pipeline schedule are not yet supported.")
queue.put("exit")
exit(1)
set_global_variables(parse_args=False)
set_global_variables(margs)
mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
fused_kernels.load(margs)
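Taken together, the loader's bootstrap now reduces to roughly the sequence below; this is a recap of the hunks above with a placeholder checkpoint path, not additional changes:

```python
import sys

sys.argv = ['loader', '--load', '/path/to/checkpoint_dir']   # placeholder path
margs = parse_args()
margs = load_args_from_checkpoint(margs)

# validate_args sanity-checks the world size, so pretend the world is exactly
# tensor-parallel x pipeline-parallel ranks.
margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size
margs = validate_args(margs)

set_global_variables(margs)
```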
@@ -28,6 +28,7 @@ def save_checkpoint(queue, args):
sys.path.insert(0, args.megatron_path)
try:
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import save_checkpoint
from megatron.global_vars import set_global_variables, get_args
from megatron.model import ModelType
@@ -46,7 +47,6 @@ def save_checkpoint(queue, args):
md = queue_get()
if args.target_tensor_parallel_size is None:
if hasattr(md, 'previous_tensor_parallel_size'):
args.target_tensor_parallel_size = md.previous_tensor_parallel_size
@@ -102,7 +102,10 @@ def save_checkpoint(queue, args):
if md.model_type == 'BERT' and not md.bert_binary_head:
sys.argv.append('--bert-no-binary-head')
set_global_variables()
margs = parse_args()
validate_args(margs)
set_global_variables(margs)
# margs = megatron args
margs = get_args()
@@ -157,6 +160,7 @@ def save_checkpoint(queue, args):
else:
print("Original vocab size not specified, leaving embedding table as-is. "
"If you've changed the tensor parallel size this could cause problems.")
margs.padded_vocab_size = orig_word_embed.shape[0]
full_word_embed = orig_word_embed
# Split into new tensor model parallel sizes