Unverified Commit c7120163 authored by Vasilis Vryniotis, committed by GitHub

Further enhance Classification Reference (#4444)

* Adding ExponentialLR and LinearLR

* Fix arg type of --lr-warmup-decay

* Adding support for zero-gamma BN and SGD with Nesterov momentum.

* Fix --lr-warmup-decay for video_classification.

* Update bn_reinit

* Fix pre-existing bug with the model's num_classes.

* Remove zero gamma.

* Use f-strings.
parent 16405ac2
@@ -186,17 +186,19 @@ def main(args):
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)

     print("Creating model")
-    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
+    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
     model.to(device)
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

     opt_name = args.opt.lower()
-    if opt_name == 'sgd':
+    if opt_name.startswith("sgd"):
         optimizer = torch.optim.SGD(
-            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+            nesterov="nesterov" in opt_name)
     elif opt_name == 'rmsprop':
         optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
                                         weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
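With this change, any --opt value that starts with "sgd" selects SGD, and Nesterov momentum is enabled whenever the string contains "nesterov" (e.g. a spelling like --opt sgd_nesterov). A minimal sketch of that parsing outside the training script; the build_sgd helper name is hypothetical:

import torch

def build_sgd(params, opt_name, lr=0.1, momentum=0.9, weight_decay=1e-4):
    # Hypothetical helper mirroring the diff: any name starting with "sgd"
    # selects SGD, and the substring "nesterov" toggles Nesterov momentum.
    assert opt_name.startswith("sgd")
    return torch.optim.SGD(params, lr=lr, momentum=momentum,
                           weight_decay=weight_decay,
                           nesterov="nesterov" in opt_name)

optimizer = build_sgd(torch.nn.Linear(2, 2).parameters(), "sgd_nesterov")
print(optimizer.defaults["nesterov"])  # True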
@@ -214,15 +216,25 @@ def main(args):
     elif args.lr_scheduler == 'cosineannealinglr':
         main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                                        T_max=args.epochs - args.lr_warmup_epochs)
+    elif args.lr_scheduler == 'exponentiallr':
+        main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
     else:
-        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR and CosineAnnealingLR "
+        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR "
                            "are supported.".format(args.lr_scheduler))

     if args.lr_warmup_epochs > 0:
+        if args.lr_warmup_method == 'linear':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                                    total_iters=args.lr_warmup_epochs)
+        elif args.lr_warmup_method == 'constant':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                                      total_iters=args.lr_warmup_epochs)
+        else:
+            raise RuntimeError(f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant "
+                               "are supported.")
         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
             optimizer,
-            schedulers=[torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
-                                                            total_iters=args.lr_warmup_epochs), main_lr_scheduler],
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
             milestones=[args.lr_warmup_epochs]
         )
     else:
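The warmup block above picks a LinearLR or ConstantLR scheduler for the first --lr-warmup-epochs epochs and chains it to the main scheduler via SequentialLR. A self-contained sketch of how the chained schedule evolves, assuming PyTorch >= 1.10 (where LinearLR, ConstantLR and SequentialLR were introduced); the toy model and hyperparameter values are illustrative only:

import torch

model = torch.nn.Linear(2, 2)  # stand-in model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.5, momentum=0.9)

# 5 linear warmup epochs: lr ramps from 0.5 * 0.01 up to 0.5,
# then ExponentialLR decays it by gamma=0.9 per epoch.
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=5)
main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[5])

for epoch in range(10):
    optimizer.step()  # a real loop would train over batches here
    lr_scheduler.step()
    print(epoch, optimizer.param_groups[0]["lr"])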
@@ -307,7 +319,9 @@ def get_args_parser(add_help=True):
     parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-method', default="constant", type=str,
+                        help='the warmup method (default: constant)')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
...
@@ -220,7 +220,7 @@ def get_args_parser(add_help=True):
                         dest='weight_decay')
     parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
     parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
-    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
...
@@ -296,7 +296,7 @@ def parse_args():
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='the number of epochs to warmup (default: 10)')
     parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
-    parser.add_argument('--lr-warmup-decay', default=0.001, type=int, help='the decay for lr')
+    parser.add_argument('--lr-warmup-decay', default=0.001, type=float, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
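The type=int to type=float fixes across these files matter because argparse applies type only to values supplied on the command line: the default 0.01 passes through untouched, but passing --lr-warmup-decay 0.01 hits int('0.01'), which raises ValueError and makes argparse exit with an error. A minimal repro of the pre-fix behavior:

import argparse

parser = argparse.ArgumentParser()
# Buggy pre-fix version: type=int cannot parse a fractional decay value.
parser.add_argument('--lr-warmup-decay', default=0.01, type=int)

print(parser.parse_args([]))  # OK: the default 0.01 is used as-is
try:
    parser.parse_args(['--lr-warmup-decay', '0.01'])
except SystemExit:
    print("argparse rejected '0.01' because int('0.01') raises ValueError")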