Commit 3dcbaec9 authored by mohammad's avatar mohammad

added flag so we don't calculate params norm all the time

parent 929c780c
@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_autoresume_args(parser)
     parser = _add_realm_args(parser)
     parser = _add_vit_args(parser)
+    parser = _add_logging_args(parser)
 
     # Custom arguments.
     if extra_args_provider is not None:
@@ -273,6 +274,15 @@ def _add_network_size_args(parser):
     return parser
 
 
+def _add_logging_args(parser):
+    group = parser.add_argument_group(title='logging')
+
+    group.add_argument('--log-params-norm', action='store_true',
+                       help='If set, calculate and log parameters norm.')
+
+    return parser
+
+
 def _add_regularization_args(parser):
     group = parser.add_argument_group(title='regularization')
 
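For reference, the new flag can be exercised in isolation. The following is a minimal standalone sketch: _add_logging_args and --log-params-norm are taken from the diff above, while the surrounding argparse wiring simplifies the full parse_args plumbing:

import argparse

def _add_logging_args(parser):
    group = parser.add_argument_group(title='logging')
    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    return parser

parser = argparse.ArgumentParser()
parser = _add_logging_args(parser)

# action='store_true' defaults to False, so the norm is opt-in:
assert parser.parse_args([]).log_params_norm is False
assert parser.parse_args(['--log-params-norm']).log_params_norm is True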
@@ -828,7 +828,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
 
         # Logging.
         loss_scale = optimizer.get_loss_scale().item()
-        params_norm = calc_params_l2_norm(model)
+        params_norm = None
+        if args.log_params_norm:
+            params_norm = calc_params_l2_norm(model)
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
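calc_params_l2_norm is defined elsewhere in the repo and is untouched by this commit. As a rough idea of the cost being gated here, a naive single-process sketch of a parameters L2 norm looks like the following (a hypothetical stand-in; the real helper also has to reduce the norm across model-parallel ranks):

import torch

def params_l2_norm_sketch(model):
    # Square root of the sum of squared parameter elements,
    # accumulated in fp32 for numerical stability.
    squared_sum = 0.0
    for param in model.parameters():
        squared_sum += param.detach().float().pow(2).sum().item()
    return squared_sum ** 0.5

Even this naive version touches every parameter once per call, which is why gating it behind --log-params-norm and defaulting params_norm to None avoids the overhead on iterations where the norm is not logged.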
@@ -187,7 +187,9 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             iteration += 1
 
             # Logging.
-            params_norm = calc_params_l2_norm(model)
+            params_norm = None
+            if args.log_params_norm:
+                params_norm = calc_params_l2_norm(model)
             report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                               optimizer.param_groups[0]['lr'],
                                               iteration,
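Note that both call sites use the identical three-line guard and fall back to params_norm = None, which keeps the training_log signature unchanged; presumably training_log simply omits the norm from its output when it receives None.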