Commit a1d04b79 authored by Jared Casper

Updating public repo with latest changes.

parent 93ab4bea
@@ -35,7 +35,34 @@ def print_rank_0(message):
    print(message, flush=True)

def enable_adlr_autoresume(args):
    print_rank_0('enabling autoresume ...')
    import sys
    sys.path.append(os.environ.get('SUBMIT_SCRIPTS','.'))
    try:
        from userlib.auto_resume import AutoResume
    except:
        print_rank_0('ADLR autoresume is not available, exiting ...')
        exit(0)
    args.AutoResume = AutoResume
    args.AutoResume.init()

def check_adlr_autoresume_termination(iteration, model, optimizer,
                                      lr_scheduler, args):
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if args.AutoResume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            args.AutoResume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        exit(0)
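For orientation, a rough sketch of how these two helpers might be wired into a training loop; the loop, train_step(), and args.train_iters are illustrative assumptions and not part of this commit:

# Illustrative sketch only: args, model, optimizer, lr_scheduler and
# train_step() are assumed to exist in the surrounding trainer.
enable_adlr_autoresume(args)
iteration = 0
while iteration < args.train_iters:
    train_step(model, optimizer, lr_scheduler, args)
    iteration += 1
    # Periodically ask the cluster scheduler whether the job should
    # checkpoint and exit so it can be resubmitted and resumed.
    check_adlr_autoresume_termination(iteration, model, optimizer,
                                      lr_scheduler, args)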

def print_args(args, writer=None):
    """Print arguments."""
    print('arguments:', flush=True)
@@ -43,6 +70,8 @@ def print_args(args):
        dots = '.' * (29 - len(arg))
        print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)
        if writer:
            writer.add_text(arg, str(getattr(args, arg)))
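The writer passed to print_args only needs an add_text method here (and add_scalar in Timers.write below), which matches a TensorBoard-style summary writer; a minimal usage sketch, assuming torch.utils.tensorboard and a hypothetical args.tensorboard_dir:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir=args.tensorboard_dir)  # args.tensorboard_dir is an assumed argument
print_args(args, writer)  # prints each argument and also mirrors it into TensorBoard as text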

def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
@@ -119,6 +148,16 @@ class Timers:
            self.timers[name] = self.Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer."""
        # Currently, when using add_scalars, the writer makes each timer its
        # own run, which pollutes the runs list, so we just add each timer
        # as an individual scalar.
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '_time', value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
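A possible way write would be used next to log, assuming the Timers object is callable and its Timer entries expose start/stop as elsewhere in this file, and that writer and iteration come from the training loop:

timers = Timers()
timers('batch generator').start()
# ... build a batch ...
timers('batch generator').stop()

# write() adds one scalar per timer (e.g. 'batch generator_time') rather than
# calling add_scalars, so each timer does not become its own TensorBoard run.
timers.write(['batch generator'], writer, iteration, normalizer=1.0)
timers.log(['batch generator'], normalizer=1.0)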
@@ -144,13 +183,13 @@ def report_memory(name):
        torch.cuda.max_memory_cached()/ mega_bytes)
    print_rank_0(string)

def get_checkpoint_name(checkpoints_path, iteration, release=False,
                        mp_rank=None):
    if release:
        d = 'release'
    else:
        d = 'iter_{:07d}'.format(iteration)
    return os.path.join(checkpoints_path, d,
                        'mp_rank_{:02d}'.format(
                            mpu.get_model_parallel_rank()
                            if mp_rank is None else mp_rank),
                        'model_optim_rng.pt')
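The new mp_rank argument lets a caller address a checkpoint shard other than its own model-parallel rank, for example to enumerate every shard of one iteration; a small illustrative sketch (iteration and model_parallel_size are assumed inputs):

# Gather the per-rank checkpoint paths for a given iteration, e.g.
#   <args.save>/iter_0010000/mp_rank_00/model_optim_rng.pt
#   <args.save>/iter_0010000/mp_rank_01/model_optim_rng.pt
paths = [get_checkpoint_name(args.save, iteration, mp_rank=rank)
         for rank in range(model_parallel_size)]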
@@ -353,3 +392,14 @@ def move_weights(our, oai, dst2src=False):
    for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h):
        load_transformer_layer(our_layer, oai_layer, dst2src)

def merge_parallel_state_dicts(state_dicts):
    # Group the values from each model-parallel state dict by key.
    temp_sd = {}
    for sd in state_dicts:
        for k, v in sd.items():
            temp_sd.setdefault(k, []).append(v)
    return temp_sd

def merge_parallel_checkpoints(checkpoint_dir, model_parallel_size):
    # Not implemented yet in this commit.
    pass
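Both merge helpers are still stubs in this commit. Purely as an assumption about where they are headed, the checkpoint variant might load each rank's shard via get_checkpoint_name and feed the state dicts to merge_parallel_state_dicts; a hypothetical sketch:

def merge_parallel_checkpoints_sketch(checkpoint_dir, iteration, model_parallel_size):
    # Hypothetical only: load every model-parallel shard onto the CPU and
    # group its tensors by key; how the grouped tensors get concatenated is
    # left open here, as it is in the stubs above.
    state_dicts = []
    for rank in range(model_parallel_size):
        name = get_checkpoint_name(checkpoint_dir, iteration, mp_rank=rank)
        state_dicts.append(torch.load(name, map_location='cpu'))
    return merge_parallel_state_dicts(state_dicts)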