Unverified commit 3925946f authored by Hu Ye, committed by GitHub

add set_weight_decay to support custom weight decay setting (#5671)



* add set_weight_decay

* Update _utils.py

* refactor code

* fix import

* add set_weight_decay

* fix lint

* fix lint

* replace split_normalization_params with set_weight_decay

* simplify the code

* refactor code

* refactor code

* fix lint

* remove unused

* Update test_ops.py

* Update train.py

* Update _utils.py

* Update train.py

* add set_weight_decay

* add set_weight_decay

* Update _utils.py

* Update test_ops.py

* Change `--transformer-weight-decay` to `--transformer-embedding-decay`
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent d59398b5
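
For context, set_weight_decay builds optimizer parameter groups so that normalization layers, biases, and transformer embedding parameters can each receive their own weight decay. A minimal usage sketch (the model, the decay values, and the assumption that `utils` is the training-reference utilities module modified below are illustrative, not part of this commit):

import torch
import torchvision
import utils  # assumed: the local utilities module below that gains set_weight_decay

model = torchvision.models.vit_b_16()
param_groups = utils.set_weight_decay(
    model,
    weight_decay=1e-4,                                # decay for all remaining parameters
    norm_weight_decay=0.0,                            # no decay on normalization layers
    custom_keys_weight_decay=[("class_token", 0.0),   # the ViT class token parameter
                              ("bias", 0.0)],         # every parameter literally named "bias"
)
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)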
@@ -229,12 +229,18 @@ def main(args):
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
-    if args.norm_weight_decay is None:
-        parameters = [p for p in model.parameters() if p.requires_grad]
-    else:
-        param_groups = torchvision.ops._utils.split_normalization_params(model)
-        wd_groups = [args.norm_weight_decay, args.weight_decay]
-        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+    custom_keys_weight_decay = []
+    if args.bias_weight_decay is not None:
+        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
+    if args.transformer_embedding_decay is not None:
+        for key in ["class_token", "position_embedding", "relative_position_bias"]:
+            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
+    parameters = utils.set_weight_decay(
+        model,
+        args.weight_decay,
+        norm_weight_decay=args.norm_weight_decay,
+        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
+    )
 
     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):
@@ -393,6 +399,18 @@ def get_args_parser(add_help=True):
         type=float,
         help="weight decay for Normalization layers (default: None, same value as --wd)",
     )
+    parser.add_argument(
+        "--bias-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
+    )
+    parser.add_argument(
+        "--transformer-embedding-decay",
+        default=None,
+        type=float,
+        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
+    )
     parser.add_argument(
         "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
     )
@@ -5,6 +5,7 @@ import hashlib
 import os
 import time
 from collections import defaultdict, deque, OrderedDict
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -400,3 +401,65 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
+
+
+def set_weight_decay(
+    model: torch.nn.Module,
+    weight_decay: float,
+    norm_weight_decay: Optional[float] = None,
+    norm_classes: Optional[List[type]] = None,
+    custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
+):
+    if not norm_classes:
+        norm_classes = [
+            torch.nn.modules.batchnorm._BatchNorm,
+            torch.nn.LayerNorm,
+            torch.nn.GroupNorm,
+            torch.nn.modules.instancenorm._InstanceNorm,
+            torch.nn.LocalResponseNorm,
+        ]
+    norm_classes = tuple(norm_classes)
+
+    params = {
+        "other": [],
+        "norm": [],
+    }
+    params_weight_decay = {
+        "other": weight_decay,
+        "norm": norm_weight_decay,
+    }
+    custom_keys = []
+    if custom_keys_weight_decay is not None:
+        for key, weight_decay in custom_keys_weight_decay:
+            params[key] = []
+            params_weight_decay[key] = weight_decay
+            custom_keys.append(key)
+
+    def _add_params(module, prefix=""):
+        for name, p in module.named_parameters(recurse=False):
+            if not p.requires_grad:
+                continue
+            is_custom_key = False
+            for key in custom_keys:
+                target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
+                if key == target_name:
+                    params[key].append(p)
+                    is_custom_key = True
+                    break
+            if not is_custom_key:
+                if norm_weight_decay is not None and isinstance(module, norm_classes):
+                    params["norm"].append(p)
+                else:
+                    params["other"].append(p)
+
+        for child_name, child_module in module.named_children():
+            child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
+            _add_params(child_module, prefix=child_prefix)
+
+    _add_params(model)
+
+    param_groups = []
+    for key in params:
+        if len(params[key]) > 0:
+            param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
+    return param_groups
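
Note the matching rule in _add_params: a custom key that contains a dot (e.g. "head.bias") is compared against the fully qualified parameter name, while a key without a dot is compared against the bare parameter name, and a matching custom key takes precedence over the normalization-layer rule. A small hedged sketch of that behaviour, using a toy model and illustrative values (it assumes set_weight_decay as defined above is in scope):

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
groups = set_weight_decay(
    model,
    weight_decay=1e-4,                         # "other" group: the Linear weight
    norm_weight_decay=0.0,                     # "norm" group: the BatchNorm1d weight
    custom_keys_weight_decay=[("bias", 0.0)],  # "bias" group: both biases, including the norm bias
)
for g in groups:
    print(len(g["params"]), g["weight_decay"])
# prints three groups: 1 param with 0.0001, 1 with 0.0, 2 with 0.0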