"...ko/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8b18cd8e7f6be0cf2904dfec4285d4ba98c5586f"
Unverified commit 3925946f, authored by Hu Ye, committed by GitHub

add set_weight_decay to support custom weight decay setting (#5671)



* add set_weight_decay

* Update _utils.py

* refactor code

* fix import

* add set_weight_decay

* fix lint

* fix lint

* replace split_normalization_params with set_weight_decay

* simplify the code

* refactor code

* refactor code

* fix lint

* remove unused

* Update test_ops.py

* Update train.py

* Update _utils.py

* Update train.py

* add set_weight_decay

* add set_weight_decay

* Update _utils.py

* Update test_ops.py

* Change `--transformer-weight-decay` to `--transformer-embedding-decay`
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent d59398b5
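
In short, the change introduces per-group weight decay: `set_weight_decay` walks the model once and returns optimizer parameter groups, so normalization layers, biases, or any parameter matched by name can use a different decay than the rest of the weights. A minimal sketch of the intended usage, not part of this commit; it assumes the patched `utils.py` below is importable (e.g. when run from `references/classification`), and the model and decay values are illustrative only:

```python
import torch
import torchvision
from utils import set_weight_decay  # utils.py as patched in this commit (references/classification)

model = torchvision.models.resnet18()

# One decay value for ordinary weights, no decay for norm layers or biases.
parameters = set_weight_decay(
    model,
    weight_decay=1e-4,
    norm_weight_decay=0.0,
    custom_keys_weight_decay=[("bias", 0.0)],  # matches any parameter literally named "bias"
)
optimizer = torch.optim.SGD(parameters, lr=0.1, momentum=0.9)
```
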
train.py
@@ -229,12 +229,18 @@ def main(args):
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

-    if args.norm_weight_decay is None:
-        parameters = [p for p in model.parameters() if p.requires_grad]
-    else:
-        param_groups = torchvision.ops._utils.split_normalization_params(model)
-        wd_groups = [args.norm_weight_decay, args.weight_decay]
-        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+    custom_keys_weight_decay = []
+    if args.bias_weight_decay is not None:
+        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
+    if args.transformer_embedding_decay is not None:
+        for key in ["class_token", "position_embedding", "relative_position_bias"]:
+            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
+    parameters = utils.set_weight_decay(
+        model,
+        args.weight_decay,
+        norm_weight_decay=args.norm_weight_decay,
+        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
+    )

     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):
@@ -393,6 +399,18 @@ def get_args_parser(add_help=True):
         type=float,
         help="weight decay for Normalization layers (default: None, same value as --wd)",
     )
+    parser.add_argument(
+        "--bias-weight-decay",
+        default=None,
+        type=float,
+        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
+    )
+    parser.add_argument(
+        "--transformer-embedding-decay",
+        default=None,
+        type=float,
+        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
+    )
     parser.add_argument(
         "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
     )
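
For reference, a hedged illustration of how the two new flags feed the grouping logic in `main()` above; the command line and values here are hypothetical:

```python
# Hypothetical invocation of the reference script:
#   torchrun --nproc_per_node=8 train.py --model vit_b_16 \
#       --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 ...
#
# With those flag values, main() above builds this list before calling utils.set_weight_decay:
custom_keys_weight_decay = [
    ("bias", 0.0),                    # from --bias-weight-decay
    ("class_token", 0.0),             # these three come from --transformer-embedding-decay
    ("position_embedding", 0.0),
    ("relative_position_bias", 0.0),
]
```
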
utils.py
@@ -5,6 +5,7 @@ import hashlib
 import os
 import time
 from collections import defaultdict, deque, OrderedDict
+from typing import List, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -400,3 +401,65 @@ def reduce_across_processes(val):
     dist.barrier()
     dist.all_reduce(t)
     return t
+
+
+def set_weight_decay(
+    model: torch.nn.Module,
+    weight_decay: float,
+    norm_weight_decay: Optional[float] = None,
+    norm_classes: Optional[List[type]] = None,
+    custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
+):
+    if not norm_classes:
+        norm_classes = [
+            torch.nn.modules.batchnorm._BatchNorm,
+            torch.nn.LayerNorm,
+            torch.nn.GroupNorm,
+            torch.nn.modules.instancenorm._InstanceNorm,
+            torch.nn.LocalResponseNorm,
+        ]
+    norm_classes = tuple(norm_classes)
+
+    params = {
+        "other": [],
+        "norm": [],
+    }
+    params_weight_decay = {
+        "other": weight_decay,
+        "norm": norm_weight_decay,
+    }
+    custom_keys = []
+    if custom_keys_weight_decay is not None:
+        for key, weight_decay in custom_keys_weight_decay:
+            params[key] = []
+            params_weight_decay[key] = weight_decay
+            custom_keys.append(key)
+
+    def _add_params(module, prefix=""):
+        for name, p in module.named_parameters(recurse=False):
+            if not p.requires_grad:
+                continue
+            is_custom_key = False
+            for key in custom_keys:
+                target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
+                if key == target_name:
+                    params[key].append(p)
+                    is_custom_key = True
+                    break
+            if not is_custom_key:
+                if norm_weight_decay is not None and isinstance(module, norm_classes):
+                    params["norm"].append(p)
+                else:
+                    params["other"].append(p)
+
+        for child_name, child_module in module.named_children():
+            child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
+            _add_params(child_module, prefix=child_prefix)
+
+    _add_params(model)
+
+    param_groups = []
+    for key in params:
+        if len(params[key]) > 0:
+            param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
+    return param_groups
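
A quick sanity check of the grouping behaviour, again only a sketch: it assumes the snippet runs next to the patched `utils.py` so the function above can be imported.

```python
import torch
from utils import set_weight_decay  # the function added above

# Toy model: conv weight, BatchNorm weight/bias, Linear weight/bias.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3, bias=False),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 10),
)

groups = set_weight_decay(
    model,
    weight_decay=1e-4,                         # conv and linear weights
    norm_weight_decay=0.0,                     # BatchNorm weight
    custom_keys_weight_decay=[("bias", 0.0)],  # anything named "bias"; custom keys win over the norm rule
)

for g in groups:
    print(len(g["params"]), "params, weight_decay =", g["weight_decay"])
# -> 2 params, weight_decay = 0.0001   ("other": conv and linear weights)
#    1 params, weight_decay = 0.0      ("norm": the BatchNorm weight)
#    2 params, weight_decay = 0.0      ("bias": BatchNorm bias and Linear bias)
```
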