Commit 8acbbe25 authored by Vijay Korthikanti
Browse files

address review comments

parent 7e810e41
......@@ -246,9 +246,14 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
if args.weight_decay is not None:
if args.wd_incr_style == 'constant':
assert args.start_wd is None
assert args.end_wd is None
args.start_wd = args.weight_decay
args.end_wd = args.weight_decay
else:
assert args.start_wd is not None
assert args.end_wd is not None
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
......@@ -399,11 +404,11 @@ def _add_regularization_args(parser):
help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.')
group.add_argument('--start-wd', type=float, default=0.01,
group.add_argument('--start-wd', type=float,
help='Initial weight decay coefficient for L2 regularization.')
group.add_argument('--end-wd', type=float, default=0.01,
group.add_argument('--end-wd', type=float,
help='End of run weight decay coefficient for L2 regularization.')
group.add_argument('--wd-incr-style', type=str, default='linear',
group.add_argument('--wd-incr-style', type=str, default='constant',
choices=['constant', 'linear', 'cosine'],
help='Weight decay increment function.')
group.add_argument('--clip-grad', type=float, default=1.0,
......
......@@ -44,6 +44,7 @@ def get_param_groups(modules,
if no_weight_decay_cond is not None:
no_wd = no_weight_decay_cond(name, param)
else:
# By default, do not regularize biases or 1-D parameters (e.g. Norm weights).
no_wd = name.endswith(".bias") or len(param.shape) == 1
if scale_lr_cond is not None:
......
......@@ -87,6 +87,10 @@ def pretrain(train_valid_test_dataset_provider,
the info we would like to monitor during training, for example
`lm-loss: value`. We also require that this function add
`batch generator` to the timers class.
process_non_loss_data_func: a function to post-process outputs of the
network. It can be used for dumping output tensors (e.g. images) to
tensorboard. It takes `collected data` (list of tensors),
`current iteration index` and `tensorboard writer` as arguments.
extra_args_provider: a function that takes a parser and adds arguments
to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment