Commit 3dcbaec9 authored by mohammad

Added a flag so we don't calculate the params norm all the time.

parent 929c780c
@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_autoresume_args(parser)
     parser = _add_realm_args(parser)
     parser = _add_vit_args(parser)
+    parser = _add_logging_args(parser)
 
     # Custom arguments.
     if extra_args_provider is not None:
@@ -273,6 +274,15 @@ def _add_network_size_args(parser):
     return parser
 
 
+def _add_logging_args(parser):
+    group = parser.add_argument_group(title='logging')
+
+    group.add_argument('--log-params-norm', action='store_true',
+                       help='If set, calculate and log parameters norm.')
+
+    return parser
+
+
 def _add_regularization_args(parser):
     group = parser.add_argument_group(title='regularization')
...
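For reference, --log-params-norm is a plain argparse store_true option, so it defaults to False and params-norm logging stays opt-in. The snippet below is a minimal standalone sketch of that behaviour, not part of this commit:

import argparse

# Recreate just the new option to show that it is off by default.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='logging')
group.add_argument('--log-params-norm', action='store_true',
                   help='If set, calculate and log parameters norm.')

args = parser.parse_args([])                      # flag absent  -> False
assert args.log_params_norm is False
args = parser.parse_args(['--log-params-norm'])   # flag present -> True
assert args.log_params_norm is True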
@@ -828,7 +828,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
 
         # Logging.
         loss_scale = optimizer.get_loss_scale().item()
-        params_norm = calc_params_l2_norm(model)
+        params_norm = None
+        if args.log_params_norm:
+            params_norm = calc_params_l2_norm(model)
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
...
@@ -187,7 +187,9 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             iteration += 1
 
             # Logging.
-            params_norm = calc_params_l2_norm(model)
+            params_norm = None
+            if args.log_params_norm:
+                params_norm = calc_params_l2_norm(model)
             report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                               optimizer.param_groups[0]['lr'],
                                               iteration,
...
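The guarded helper calc_params_l2_norm itself is not shown in this commit. As a rough illustration of the cost the new flag avoids, a parameters-L2-norm routine has to visit every parameter tensor on every call. The sketch below is illustrative only: the name params_l2_norm and its body are assumptions, and the real Megatron helper may additionally handle model-parallel reduction.

import torch

def params_l2_norm(model: torch.nn.Module) -> float:
    # Square root of the sum of squares over every parameter element.
    total = torch.zeros(1)
    for param in model.parameters():
        total += param.detach().float().pow(2).sum().cpu()
    return total.sqrt().item()

# Even this naive version touches every parameter each iteration, which is
# why the training loops above only compute it when --log-params-norm is set.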