Commit 8acbbe25 authored by Vijay Korthikanti

address review comments

parent 7e810e41
@@ -246,9 +246,14 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.fp16 or args.bf16, \
             'residual connection in fp32 only supported when using fp16 or bf16.'
 
-    if args.weight_decay is not None:
+    if args.wd_incr_style == 'constant':
+        assert args.start_wd is None
+        assert args.end_wd is None
         args.start_wd = args.weight_decay
         args.end_wd = args.weight_decay
+    else:
+        assert args.start_wd is not None
+        assert args.end_wd is not None
 
     TORCH_MAJOR = int(torch.__version__.split('.')[0])
     TORCH_MINOR = int(torch.__version__.split('.')[1])
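For reference, a minimal sketch of the increment schedule these endpoints imply. This is an illustration only: the function name, signature, and the exact cosine form are assumptions, not the scheduler code from this commit.

```python
import math

def weight_decay_at(step, total_steps, start_wd, end_wd, incr_style):
    # Illustrative schedule for the three --wd-incr-style choices:
    # interpolate the decay coefficient from start_wd to end_wd.
    if incr_style == 'constant':
        # parse_args above has already set start_wd == end_wd == weight_decay
        return start_wd
    frac = min(step / total_steps, 1.0)  # fraction of training completed
    if incr_style == 'linear':
        coeff = frac
    else:  # 'cosine': smooth ramp from 0 to 1 over half a cosine period
        coeff = 0.5 * (math.cos(math.pi * (1.0 - frac)) + 1.0)
    return start_wd + coeff * (end_wd - start_wd)
```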
@@ -399,11 +404,11 @@ def _add_regularization_args(parser):
                        help='Dropout probability for hidden state transformer.')
     group.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay coefficient for L2 regularization.')
-    group.add_argument('--start-wd', type=float, default=0.01,
+    group.add_argument('--start-wd', type=float,
                        help='Initial weight decay coefficient for L2 regularization.')
-    group.add_argument('--end-wd', type=float, default=0.01,
+    group.add_argument('--end-wd', type=float,
                        help='End of run weight decay coefficient for L2 regularization.')
-    group.add_argument('--wd-incr-style', type=str, default='linear',
+    group.add_argument('--wd-incr-style', type=str, default='constant',
                        choices=['constant', 'linear', 'cosine'],
                        help='Weight decay increment function.')
     group.add_argument('--clip-grad', type=float, default=1.0,
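With the changed defaults, `--start-wd` and `--end-wd` stay unset (`None`) unless passed explicitly, which is exactly what the new assertions in `parse_args` check. A self-contained sketch of the flag behavior, assuming a bare `argparse` parser rather than Megatron's full argument setup:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--weight-decay', type=float, default=0.01)
parser.add_argument('--start-wd', type=float)   # no default: None unless given
parser.add_argument('--end-wd', type=float)     # no default: None unless given
parser.add_argument('--wd-incr-style', type=str, default='constant',
                    choices=['constant', 'linear', 'cosine'])

# A non-constant style now requires both endpoints to be given explicitly:
args = parser.parse_args(
    ['--wd-incr-style', 'cosine', '--start-wd', '0.01', '--end-wd', '0.1'])
print(args.start_wd, args.end_wd, args.wd_incr_style)  # 0.01 0.1 cosine
```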
@@ -44,6 +44,7 @@ def get_param_groups(modules,
         if no_weight_decay_cond is not None:
             no_wd = no_weight_decay_cond(name, param)
         else:
+            # do not regularize biases nor Norm parameters
             no_wd = name.endswith(".bias") or len(param.shape) == 1
         if scale_lr_cond is not None:
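The newly commented default excludes biases and every 1-D parameter (LayerNorm weights and the like) from weight decay. A runnable sketch of that grouping, with a signature simplified relative to `get_param_groups`, which also handles the user-supplied conditions and learning-rate scaling:

```python
import torch

def split_decay_groups(model, weight_decay):
    # Partition parameters: biases and 1-D tensors get no regularization.
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if name.endswith('.bias') or len(param.shape) == 1:
            no_decay.append(param)
        else:
            decay.append(param)
    return [{'params': decay, 'weight_decay': weight_decay},
            {'params': no_decay, 'weight_decay': 0.0}]

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
groups = split_decay_groups(model, weight_decay=0.01)
print([len(g['params']) for g in groups])  # [1, 3]
```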
@@ -87,6 +87,10 @@ def pretrain(train_valid_test_dataset_provider,
         the info we would like to monitor during training, for example
         `lm-loss: value`. We also require that this function add
         `batch generator` to the timers class.
+    process_non_loss_data_func: a function to post-process outputs of the
+        network. It can be used for dumping output tensors (e.g. images) to
+        tensorboard. It takes `collected data` (list of tensors),
+        `current iteration index` and `tensorboard writer` as arguments.
     extra_args_provider: a function that takes a parser and adds arguments
         to it. It is used for programs to add their own arguments.
     args_defaults: a dictionary from argument-name to argument-value. It
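A hypothetical `process_non_loss_data_func` matching the documented signature, assuming the collected tensors are image batches in `[N, C, H, W]` layout (the function and tag names are made up):

```python
def dump_images_to_tensorboard(collected_data, iteration, writer):
    # collected_data: list of tensors; writer: a SummaryWriter-like object.
    for i, images in enumerate(collected_data):
        writer.add_images(f'validation/output_{i}', images,
                          global_step=iteration)
```

It would then be passed through as `pretrain(..., process_non_loss_data_func=dump_images_to_tensorboard)`.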