"vscode:/vscode.git/clone" did not exist on "be2070991f1b916977020c45ecdfec225de21862"
Commit a54978bb authored by Jared Casper

Merge branch 'staging' into 'master'

Updating public repo with latest changes.

See merge request ADLR/megatron-lm!1
parents 93ab4bea 5d402eb4
@@ -35,7 +35,34 @@ def print_args(args):
    print(message, flush=True)


def print_args(args):

def enable_adlr_autoresume(args):
    print_rank_0('enabling autoresume ...')
    import sys
    sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
    try:
        from userlib.auto_resume import AutoResume
    except ImportError:
        print_rank_0('ADLR autoresume is not available, exiting ...')
        exit(0)

    args.AutoResume = AutoResume
    args.AutoResume.init()


def check_adlr_autoresume_termination(iteration, model, optimizer,
                                      lr_scheduler, args):
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if args.AutoResume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            args.AutoResume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        exit(0)
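In a training script, these helpers would typically bracket the training loop: enable_adlr_autoresume once at startup, then the termination check every so many iterations. A minimal sketch, assuming a hypothetical train_step function and an illustrative check interval (neither is part of this commit):

# Sketch only: wiring the autoresume helpers above into a training loop.
enable_adlr_autoresume(args)
iteration = 0
while iteration < args.train_iters:
    train_step(model, optimizer, lr_scheduler, args)  # hypothetical step function
    iteration += 1
    if iteration % 100 == 0:  # illustrative check interval
        check_adlr_autoresume_termination(iteration, model, optimizer,
                                          lr_scheduler, args)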
def print_args(args, writer=None):
"""Print arguments."""
print('arguments:', flush=True)
@@ -43,6 +70,8 @@ def print_args(args):
        dots = '.' * (29 - len(arg))
        print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)
        if writer:
            writer.add_text(arg, str(getattr(args, arg)))
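The new writer argument lets the argument dump also be recorded in TensorBoard as text entries. A minimal sketch of calling it that way, assuming torch.utils.tensorboard's SummaryWriter (any writer with add_text would do) and an illustrative log directory:

# Sketch only: log arguments to stdout and, as text entries, to TensorBoard.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/example')  # illustrative path
print_args(args, writer=writer)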
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
@@ -119,6 +148,16 @@ class Timers:
            self.timers[name] = self.Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer."""
        # Currently, when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # pollutes the runs list, so we just add each as a scalar.
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '_time', value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
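Timers.write mirrors Timers.log but targets TensorBoard, adding one scalar per timer rather than using add_scalars (which would create a separate run per timer). A minimal usage sketch, with illustrative timer names and logging interval (not part of this commit):

# Sketch only: periodically flush a group of timers to TensorBoard.
if writer is not None and iteration % log_interval == 0:
    timers.write(['forward', 'backward', 'optimizer'], writer, iteration,
                 normalizer=log_interval, reset=True)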
@@ -144,13 +183,13 @@ def report_memory(name):
        torch.cuda.max_memory_cached() / mega_bytes)
    print_rank_0(string)
def get_checkpoint_name(checkpoints_path, iteration, release=False):
def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
    if release:
        d = 'release'
    else:
        d = 'iter_{:07d}'.format(iteration)
    return os.path.join(checkpoints_path, d,
                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank()),
                        'mp_rank_{:02d}'.format(
                            mpu.get_model_parallel_rank() if mp_rank is None else mp_rank),
                        'model_optim_rng.pt')
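The optional mp_rank argument lets a caller construct the checkpoint path of a model-parallel rank other than its own, which is what checkpoint-merging code needs. For example (the checkpoint directory is illustrative):

# Sketch: the path produced for iteration 20000 and model-parallel rank 1.
# get_checkpoint_name('checkpoints/gpt2', 20000, mp_rank=1)
#   -> 'checkpoints/gpt2/iter_0020000/mp_rank_01/model_optim_rng.pt'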
@@ -353,3 +392,14 @@ def move_weights(our, oai, dst2src=False):
    for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h):
        load_transformer_layer(our_layer, oai_layer, dst2src)
def merge_parallel_state_dicts(state_dicts):
    temp_sd = {}
    for sd in state_dicts:
        for k, v in sd.items():
            # Collect the per-rank shards for each key.
            temp_sd.setdefault(k, []).append(v)
    # TODO: combine the collected shards per key.
    pass


def merge_parallel_checkpoints(checkpoint_dir, model_parallel_size):
    pass
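Both merge helpers are still stubs in this commit. A rough sketch of the loading side they would need, using torch.load and the get_checkpoint_name change above; how the per-key shards get combined is left open, as in the stub:

# Sketch only: collect each model-parallel rank's state dict for one iteration.
# This helper is illustrative and not part of the commit.
def load_parallel_state_dicts(checkpoint_dir, iteration, model_parallel_size):
    state_dicts = []
    for rank in range(model_parallel_size):
        path = get_checkpoint_name(checkpoint_dir, iteration, mp_rank=rank)
        state_dicts.append(torch.load(path, map_location='cpu'))
    return state_dicts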