Commit a1d04b79 authored by Jared Casper

Updating public repo with latest changes.

parent 93ab4bea
@@ -35,7 +35,34 @@ def print_rank_0(message):
         print(message, flush=True)
-def print_args(args):
+def enable_adlr_autoresume(args):
+    print_rank_0('enabling autoresume ...')
+    import sys
+    sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
+    try:
+        from userlib.auto_resume import AutoResume
+    except ImportError:
+        print_rank_0('ADLR autoresume is not available, exiting ...')
+        exit(0)
+    args.AutoResume = AutoResume
+    args.AutoResume.init()
+def check_adlr_autoresume_termination(iteration, model, optimizer,
+                                      lr_scheduler, args):
+    # Add barrier to ensure consistency.
+    torch.distributed.barrier()
+    if args.AutoResume.termination_requested():
+        if args.save:
+            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
+        print_rank_0(">>> autoresume termination request found!")
+        if torch.distributed.get_rank() == 0:
+            args.AutoResume.request_resume()
+        print_rank_0(">>> training terminated. Returning")
+        exit(0)
+def print_args(args, writer=None):
     """Print arguments."""
     print('arguments:', flush=True)
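A hedged sketch of how the new check_adlr_autoresume_termination helper might be wired into a training loop. The loop, the train_step call, the adlr_autoresume flag, and the check interval below are illustrative only and are not part of this commit:

# Hypothetical call site; 'train_step', 'check_interval', and the surrounding
# loop are illustrative, not taken from this commit.
check_interval = 100
while iteration < args.train_iters:
    loss = train_step(model, optimizer, lr_scheduler, args)
    iteration += 1
    if args.adlr_autoresume and iteration % check_interval == 0:
        # Exits the job (after optionally saving a checkpoint) when the
        # cluster has requested termination/resume.
        check_adlr_autoresume_termination(iteration, model, optimizer,
                                          lr_scheduler, args)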
@@ -43,6 +70,8 @@ def print_args(args):
         dots = '.' * (29 - len(arg))
         print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)
+        if writer:
+            writer.add_text(arg, str(getattr(args, arg)))
 def print_params_min_max_norm(optimizer, iteration):
     """Print min, max, and norm of all parameters."""
@@ -119,6 +148,16 @@ class Timers:
             self.timers[name] = self.Timer(name)
         return self.timers[name]
+    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
+        """Write timers to a tensorboard writer."""
+        # Currently, when using add_scalars, torch.utils.tensorboard makes
+        # each timer its own run, which pollutes the runs list, so we just
+        # add each timer as a scalar.
+        assert normalizer > 0.0
+        for name in names:
+            value = self.timers[name].elapsed(reset=reset) / normalizer
+            writer.add_scalar(name + '_time', value, iteration)
     def log(self, names, normalizer=1.0, reset=True):
         """Log a group of timers."""
         assert normalizer > 0.0
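The new Timers.write method only assumes a TensorBoard-style writer exposing add_scalar(tag, value, step). A minimal usage sketch, assuming the per-name Timer objects expose start()/stop() as elsewhere in this file; the writer, log directory, and timer name below are illustrative:

from torch.utils.tensorboard import SummaryWriter

# Hypothetical usage; 'runs/megatron' and the 'forward' timer are illustrative.
writer = SummaryWriter(log_dir='runs/megatron')
timers = Timers()

timers('forward').start()
# ... run the forward pass ...
timers('forward').stop()

# Each timer is logged under its own scalar tag, e.g. 'forward_time',
# rather than grouped via add_scalars (which would create extra runs).
timers.write(['forward'], writer, iteration=100, normalizer=1.0)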
@@ -144,13 +183,13 @@ def report_memory(name):
         torch.cuda.max_memory_cached() / mega_bytes)
     print_rank_0(string)
-def get_checkpoint_name(checkpoints_path, iteration, release=False):
+def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
     if release:
         d = 'release'
     else:
         d = 'iter_{:07d}'.format(iteration)
     return os.path.join(checkpoints_path, d,
-                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank()),
+                        'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank() if mp_rank is None else mp_rank),
                         'model_optim_rng.pt')
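The new mp_rank argument lets a caller build the checkpoint path for an arbitrary model-parallel rank instead of the current process's rank. A brief sketch; the root directory, iteration, and rank count are chosen purely for illustration:

# Hypothetical example: paths, iteration, and rank count are illustrative only.
ckpt_root = '/checkpoints/gpt2'
iteration = 5000

# Default behaviour: path for this process's own model-parallel rank.
own_path = get_checkpoint_name(ckpt_root, iteration)

# Explicit rank: e.g. gather every shard when merging parallel checkpoints.
shard_paths = [get_checkpoint_name(ckpt_root, iteration, mp_rank=r)
               for r in range(4)]
# -> .../iter_0005000/mp_rank_00/model_optim_rng.pt ... mp_rank_03/...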
@@ -353,3 +392,14 @@ def move_weights(our, oai, dst2src=False):
     for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h):
         load_transformer_layer(our_layer, oai_layer, dst2src)
+def merge_parallel_state_dicts(state_dicts):
+    # Group each parameter's model-parallel shards under its key; how the
+    # shards are concatenated is not implemented in this commit.
+    temp_sd = {}
+    for sd in state_dicts:
+        for k, v in sd.items():
+            temp_sd.setdefault(k, []).append(v)
+    return temp_sd
+def merge_parallel_checkpoints(checkpoint_dir, model_parallel_size):
+    pass
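merge_parallel_checkpoints is left as a stub in this commit. One possible shape for it, assuming it is meant to load each model-parallel shard with torch.load from the mp_rank_XX/ layout produced by get_checkpoint_name and hand the list of state dicts to merge_parallel_state_dicts; this is a hypothetical sketch, not the committed implementation:

import os
import torch

def merge_parallel_checkpoints(checkpoint_dir, model_parallel_size):
    # Hypothetical sketch: assumes checkpoint_dir is an iteration (or release)
    # directory containing one mp_rank_XX/ subdirectory per parallel rank.
    state_dicts = []
    for rank in range(model_parallel_size):
        path = os.path.join(checkpoint_dir,
                            'mp_rank_{:02d}'.format(rank),
                            'model_optim_rng.pt')
        state_dicts.append(torch.load(path, map_location='cpu'))
    # Returns the per-key lists of shards collected by merge_parallel_state_dicts.
    return merge_parallel_state_dicts(state_dicts)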