Commit 942c402d authored by Jared Casper

Making loading arguments from checkpoint cleaner and available more broadly.

parent 06fc51ce
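
In rough terms, this commit splits argument handling into four explicit steps: parse, optionally load arguments from a checkpoint, validate, then install the result as globals. A condensed sketch of the new flow, pieced together from the hunks below (all names come from the diff; the surrounding setup and error handling in initialize_megatron are omitted):

from megatron.arguments import parse_args, validate_args
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables

args = parse_args()                  # parse_args no longer validates or applies defaults
if args.use_checkpoint_args:         # new flag added in this commit
    assert args.load is not None     # a checkpoint directory is required
    load_args_from_checkpoint(args)  # overwrite args with values saved in the checkpoint
validate_args(args)                  # validation is now a separate, explicit call
set_global_variables(args)           # global vars now receive an already-parsed Namespace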
@@ -25,22 +25,6 @@ from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
 from .initialize import initialize_megatron
-def print_rank_0(message):
-    """If distributed is initialized, print only on rank 0."""
-    if torch.distributed.is_initialized():
-        if torch.distributed.get_rank() == 0:
-            print(message, flush=True)
-    else:
-        print(message, flush=True)
-
-def is_last_rank():
-    return torch.distributed.get_rank() == (
-        torch.distributed.get_world_size() - 1)
-
-def print_rank_last(message):
-    """If distributed is initialized, print only on last rank."""
-    if torch.distributed.is_initialized():
-        if is_last_rank():
-            print(message, flush=True)
-    else:
-        print(message, flush=True)
+from .utils import (print_rank_0,
+                    is_last_rank,
+                    print_rank_last)
@@ -20,8 +20,7 @@ import os
 import torch

-def parse_args(extra_args_provider=None, defaults={},
-               ignore_unknown_args=False, validate=True):
+def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     """Parse all arguments."""
     parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                      allow_abbrev=False)
@@ -53,14 +52,13 @@ def parse_args(extra_args_provider=None, defaults={},
     else:
         args = parser.parse_args()

-    if validate:
-        return validate_args(args, defaults)
+    # Args from environment
+    args.rank = int(os.getenv('RANK', '0'))
+    args.world_size = int(os.getenv("WORLD_SIZE", '1'))

     return args

 def validate_args(args, defaults={}):
-    # Distributed args.
-    args.rank = int(os.getenv('RANK', '0'))
-    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
-
     # Tensor model parallel size.
     args.tensor_model_parallel_size = min(
         args.tensor_model_parallel_size, args.world_size)
@@ -628,6 +626,9 @@ def _add_checkpointing_args(parser):
                        'can reduce startup time when definitely loading from a '
                        'checkpoint',
                        dest='perform_initialization')
+    group.add_argument('--use-checkpoint-args', action='store_true',
+                       help='Override any command line arguments with arguments '
+                       'from the checkpoint')

     return parser
...
@@ -22,11 +22,12 @@ import numpy as np
 import torch

-from megatron import (get_args,
-                      mpu,
-                      print_rank_0,
-                      update_num_microbatches,
-                      utils)
+from megatron import (mpu,
+                      update_num_microbatches)
+from .global_vars import get_args
+from .utils import (unwrap_model,
+                    print_rank_0)

 _CHECKPOINT_VERSION = None
@@ -207,7 +208,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    model = utils.unwrap_model(model)
+    model = unwrap_model(model)

     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
@@ -386,8 +387,11 @@ def _load_base_checkpoint(load_dir, rank0=False):
     return state_dict, release

 def load_args_from_checkpoint(args, load_arg='load'):
-    """Set any arguments that are not currently set from the checkpoint
-    specified in the arguments.
+    """Set required arguments from the checkpoint specified in the
+    arguments.
+
+    Will overwrite arguments that have a non-None default value, but
+    will leave any arguments that default to None as set.

     Returns the same args NameSpace with the new values added/updated.
@@ -406,6 +410,7 @@ def load_args_from_checkpoint(args, load_arg='load'):
         return args

     if 'args' not in state_dict:
+        print('Checkpoint provided does not have arguments saved.')
         return args

     checkpoint_args = state_dict['args']
@@ -422,7 +427,7 @@ def load_args_from_checkpoint(args, load_arg='load'):
         checkpoint_value = getattr(checkpoint_args, arg_name, None)
         if checkpoint_value is not None:
-            print(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
+            print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
             setattr(args, arg_name, checkpoint_value)

     _set_arg('num_layers')
@@ -453,7 +458,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     args = get_args()
     load_dir = getattr(args, load_arg)

-    model = utils.unwrap_model(model)
+    model = unwrap_model(model)

     state_dict, release = _load_base_checkpoint(load_dir, False)
@@ -574,7 +579,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
     args = get_args()
-    model = utils.unwrap_model(model)
+    model = unwrap_model(model)

     load_path = custom_load_path if custom_load_path is not None else args.load
...
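
The override rule described in the updated docstring is implemented by the small helper visible in the hunk above. A stand-alone sketch of the same logic (in the diff, _set_arg takes only the argument name and closes over args and checkpoint_args; the explicit parameters here are just for illustration):

def _set_arg(args, checkpoint_args, arg_name):
    # Copy a value from the checkpoint's saved args onto the live args,
    # but only if the checkpoint actually recorded one.
    checkpoint_value = getattr(checkpoint_args, arg_name, None)
    if checkpoint_value is not None:
        print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
        setattr(args, arg_name, checkpoint_value)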
@@ -23,7 +23,6 @@ import torch
 from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer

-from .arguments import parse_args
 from .microbatches import build_num_microbatches_calculator

 _GLOBAL_ARGS = None
@@ -86,16 +85,14 @@ def _set_signal_handler():
     _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
     _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()

-def set_global_variables(extra_args_provider=None, args_defaults={},
-                         ignore_unknown_args=False, parse_args=True):
+def set_global_variables(args):
     """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
-    if parse_args:
-        args = _parse_args(extra_args_provider=extra_args_provider,
-                           defaults=args_defaults,
-                           ignore_unknown_args=ignore_unknown_args)
-    else:
-        _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
-        args = get_args()
+
+    assert args is not None
+
+    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
+    set_args(args)
+
     _build_num_microbatches_calculator(args)
     if args.vocab_file:
         _ = _build_tokenizer(args)
@@ -117,10 +114,9 @@ def _parse_args(extra_args_provider=None, defaults={},
     """Parse entire arguments."""
     global _GLOBAL_ARGS
     _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
-    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
-                              defaults=defaults,
-                              ignore_unknown_args=ignore_unknown_args,
-                              validate=True)
+
+    _GLOBAL_ARGS = args

     return _GLOBAL_ARGS
...
@@ -28,6 +28,8 @@ from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron import mpu
+from megatron.arguments import (parse_args, validate_args)
+from megatron.checkpointing import load_args_from_checkpoint
 from megatron.global_vars import set_global_variables
 from megatron.mpu import (set_tensor_model_parallel_rank,
                           set_tensor_model_parallel_world_size)
@@ -47,11 +49,18 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     # Make sure cuda is available.
     assert torch.cuda.is_available(), 'Megatron requires CUDA.'

-    # Parse args, build tokenizer, and set adlr-autoresume,
+    # Parse arguments
+    args = parse_args(extra_args_provider, ignore_unknown_args)
+
+    if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
+        assert args.load is not None, '--use-checkpoints-args requires --load argument'
+        load_args_from_checkpoint(args)
+
+    validate_args(args, args_defaults)
+
+    # set global args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
-    set_global_variables(extra_args_provider=extra_args_provider,
-                         args_defaults=args_defaults,
-                         ignore_unknown_args=ignore_unknown_args)
+    set_global_variables(args)

     # torch.distributed initialization
     def finish_mpu_init():
...
@@ -72,16 +72,6 @@ class MegatronModule(torch.nn.Module):
         if args.pipeline_model_parallel_size == 1:
             return

-        if not torch.distributed.is_initialized():
-            if not getattr(MegatronModule, "embedding_warning_printed", False):
-                print("WARNING! Distributed processes aren't initialized, so "
-                      "word embeddings in the last layer are not initialized. "
-                      "If you are just manipulating a model this is fine, but "
-                      "this needs to be handled manually. If you are training "
-                      "something is definitely wrong.")
-                MegatronModule.embedding_warning_printed = True
-            return
-
         # Parameters are shared between the word embeddings layers, and the
         # heads at the end of the model. In a pipelined setup with more than
         # one stage, the initial embedding layer and the head are on different
@@ -112,6 +102,16 @@ class MegatronModule(torch.nn.Module):
                 self.pre_process:
             self.language_model.embedding.zero_parameters()

+        if not torch.distributed.is_initialized():
+            if not getattr(MegatronModule, "embedding_warning_printed", False):
+                print("WARNING! Distributed processes aren't initialized, so "
+                      "word embeddings in the last layer are not initialized. "
+                      "If you are just manipulating a model this is fine, but "
+                      "this needs to be handled manually. If you are training "
+                      "something is definitely wrong.")
+                MegatronModule.embedding_warning_printed = True
+            return
+
         # Ensure that first and last stages have the same initial parameter
         # values.
         if mpu.is_rank_in_embedding_group():
...
@@ -24,7 +24,6 @@ from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

 from megatron import get_args
-from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.model.module import param_is_not_shared
@@ -204,3 +203,22 @@ def get_ltor_masks_and_position_ids(data,
     return attention_mask, loss_mask, position_ids

+
+def print_rank_0(message):
+    """If distributed is initialized, print only on rank 0."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 0:
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
+
+
+def is_last_rank():
+    return torch.distributed.get_rank() == (
+        torch.distributed.get_world_size() - 1)
+
+
+def print_rank_last(message):
+    """If distributed is initialized, print only on last rank."""
+    if torch.distributed.is_initialized():
+        if is_last_rank():
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
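
The three printing helpers now live in the megatron.utils module and are re-exported from the megatron package (see the first hunk of this commit), so existing call sites keep working. A minimal usage sketch (the messages are illustrative):

from megatron.utils import print_rank_0, print_rank_last

# Both helpers fall back to a plain print when torch.distributed is not
# initialized; otherwise they print only on rank 0 / the last rank.
print_rank_0("building the tokenizer ...")
print_rank_last("evaluation results ...")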
@@ -29,7 +29,7 @@ def _load_checkpoint(queue, args):
         from megatron.arguments import parse_args, validate_args
         from megatron.global_vars import set_args, set_global_variables
         from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
-        from megatron.model import ModelType
+        from megatron.model import ModelType, module
         from megatron import mpu, fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
@@ -51,9 +51,15 @@ def _load_checkpoint(queue, args):
                 '--load', args.load_dir
                 ]
-    margs = parse_args(validate=False)
+    margs = parse_args()
     margs = load_args_from_checkpoint(margs)

+    # Arguments do sanity checks on the world size, but we don't care,
+    # so trick it into thinking we are plenty of processes
+    margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size
+
+    margs = validate_args(margs)
+
     def check_for_arg(arg_name):
         if getattr(margs, arg_name, None) is None:
             print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
@@ -71,13 +77,6 @@ def _load_checkpoint(queue, args):
     check_for_arg('tokenizer_type')
     check_for_arg('iteration')
     check_for_arg('bert_binary_head')
-
-    # Arguments do sanity checks on the world size, but we don't care,
-    # so trick it into thinking we are plenty of processes
-    os.environ["WORLD_SIZE"] = f'{margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size}'
-
-    margs = validate_args(margs)
-
     check_for_arg('params_dtype')

     # Determine how to make our models
@@ -90,6 +89,9 @@ def _load_checkpoint(queue, args):
     else:
         raise Exception(f'unrecognized model type: {args.model_type}')

+    # supress warning about torch.distributed not being initialized
+    module.MegatronModule.embedding_warning_printed = True
+
     def get_models(count, dtype, pre_process, post_process):
         # with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
         #     futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
@@ -105,14 +107,12 @@ def _load_checkpoint(queue, args):
             models.append(model_[0])
         return models

-    set_args(margs)
-
     if margs.num_layers_per_virtual_pipeline_stage is not None:
         print("Model with an interleaved pipeline schedule are not yet supported.")
         queue.put("exit")
         exit(1)

-    set_global_variables(parse_args=False)
+    set_global_variables(margs)
     mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
     mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
     fused_kernels.load(margs)
...
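
For the checkpoint conversion loader, the net effect is that Megatron's argument machinery can be driven entirely in-process, without a real distributed launch. A condensed sketch of the sequence taken from the hunks above (the sys.argv setup and model-building code are omitted):

margs = parse_args()                       # parse only; no validation yet
margs = load_args_from_checkpoint(margs)   # pull model/tokenizer args from the checkpoint

# validate_args sanity-checks the world size, so pretend there is one
# process per model-parallel rank before validating.
margs.world_size = (margs.tensor_model_parallel_size *
                    margs.pipeline_model_parallel_size)
margs = validate_args(margs)

set_global_variables(margs)                # install the finished args as globals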
@@ -28,6 +28,7 @@ def save_checkpoint(queue, args):
         sys.path.insert(0, args.megatron_path)

     try:
+        from megatron.arguments import (parse_args, validate_args)
         from megatron.checkpointing import save_checkpoint
         from megatron.global_vars import set_global_variables, get_args
         from megatron.model import ModelType
@@ -46,7 +47,6 @@ def save_checkpoint(queue, args):
     md = queue_get()
-
     if args.target_tensor_parallel_size is None:
         if hasattr(md, 'previous_tensor_parallel_size'):
             args.target_tensor_parallel_size = md.previous_tensor_parallel_size
@@ -102,7 +102,10 @@ def save_checkpoint(queue, args):
     if md.model_type == 'BERT' and not md.bert_binary_head:
         sys.argv.append('--bert-no-binary-head')

-    set_global_variables()
+    margs = parse_args()
+    validate_args(margs)
+    set_global_variables(margs)

     # margs = megatron args
     margs = get_args()
@@ -157,6 +160,7 @@ def save_checkpoint(queue, args):
     else:
         print("Original vocab size not specified, leaving embedding table as-is. "
               "If you've changed the tensor parallel size this could cause problems.")
+        margs.padded_vocab_size = orig_word_embed.shape[0]
         full_word_embed = orig_word_embed

     # Split into new tensor model parallel sizes
...