Unverified Commit 12fd3a62 authored by Prabhat Roy, committed by GitHub

Added Exponential Moving Average support to classification reference script (#4381)

* Added Exponential Moving Average support to classification reference script

* Addressed review comments

* Updated model argument
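
With this change, EMA tracking can be switched on from the command line, e.g. `python train.py --model resnet18 --model-ema --model-ema-decay 0.99` (the model name here is illustrative; the two `--model-ema*` flags are the ones added below). The decay controls how closely the average tracks the live weights: with decay 0.99 the EMA has an effective horizon of roughly 1/(1 - decay) = 100 updates.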
parent c50d0fcc
train.py
@@ -17,7 +17,8 @@ except ImportError:
     amp = None
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
+                    print_freq, apex=False, model_ema=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
@@ -45,11 +46,14 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
         metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
 
+        if model_ema:
+            model_ema.update_parameters(model)
+
 
-def evaluate(model, criterion, data_loader, device, print_freq=100):
+def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
-    header = 'Test:'
+    header = f'Test: {log_suffix}'
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
@@ -199,12 +203,18 @@ def main(args):
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
 
+    model_ema = None
+    if args.model_ema:
+        model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
         args.start_epoch = checkpoint['epoch'] + 1
+        if model_ema:
+            model_ema.load_state_dict(checkpoint['model_ema'])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -215,9 +225,11 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
+        if model_ema:
+            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
         if args.output_dir:
             checkpoint = {
                 'model': model_without_ddp.state_dict(),
@@ -225,6 +237,8 @@ def main(args):
                 'lr_scheduler': lr_scheduler.state_dict(),
                 'epoch': epoch,
                 'args': args}
+            if model_ema:
+                checkpoint['model_ema'] = model_ema.state_dict()
             utils.save_on_master(
                 checkpoint,
                 os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
@@ -306,6 +320,12 @@ def get_args_parser(add_help=True):
     parser.add_argument('--world-size', default=1, type=int,
                         help='number of distributed processes')
     parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+    parser.add_argument(
+        '--model-ema', action='store_true',
+        help='enable tracking Exponential Moving Average of model parameters')
+    parser.add_argument(
+        '--model-ema-decay', type=float, default=0.99,
+        help='decay factor for Exponential Moving Average of model parameters (default: 0.99)')
 
     return parser
utils.py
@@ -161,6 +161,18 @@ class MetricLogger(object):
         print('{} Total time: {}'.format(header, total_time_str))
 
 
+class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
+    """Maintains moving averages of model parameters using an exponential decay.
+    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
+    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
+    is used to compute the EMA.
+    """
+    def __init__(self, model, decay, device='cpu'):
+        ema_avg = (lambda avg_model_param, model_param, num_averaged:
+                   decay * avg_model_param + (1 - decay) * model_param)
+        super().__init__(model, device, ema_avg)
+
+
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
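
For reference, a minimal self-contained sketch (not part of the commit) showing how torch.optim.swa_utils.AveragedModel applies a custom avg_fn like the one above; note that the first update_parameters call simply copies the weights, and the decay rule only kicks in from the second call onward:

    import torch

    decay = 0.9
    model = torch.nn.Linear(1, 1, bias=False)
    with torch.no_grad():
        model.weight.fill_(1.0)

    # Same update rule as ExponentialMovingAverage above:
    ema_avg = (lambda avg_param, param, num_averaged:
               decay * avg_param + (1 - decay) * param)
    ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

    ema_model.update_parameters(model)       # first call copies: ema weight = 1.0
    with torch.no_grad():
        model.weight.fill_(0.0)              # pretend an optimizer step changed the weights
    ema_model.update_parameters(model)       # ema = 0.9 * 1.0 + 0.1 * 0.0 = 0.9
    print(ema_model.module.weight.item())    # prints ~0.9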