Unverified Commit 12fd3a62 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Added Exponential Moving Average support to classification reference script (#4381)

* Added Exponential Moving Average support to classification reference script

* Addressed review comments

* Updated model argument
parent c50d0fcc
......@@ -17,7 +17,8 @@ except ImportError:
amp = None
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
print_freq, apex=False, model_ema=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
......@@ -45,11 +46,14 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
if model_ema:
model_ema.update_parameters(model)
def evaluate(model, criterion, data_loader, device, print_freq=100):
def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
header = f'Test: {log_suffix}'
with torch.no_grad():
for image, target in metric_logger.log_every(data_loader, print_freq, header):
image = image.to(device, non_blocking=True)
......@@ -199,12 +203,18 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
model_ema = None
if args.model_ema:
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if model_ema:
model_ema.load_state_dict(checkpoint['model_ema'])
if args.test_only:
evaluate(model, criterion, data_loader_test, device=device)
......@@ -215,9 +225,11 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
if args.output_dir:
checkpoint = {
'model': model_without_ddp.state_dict(),
......@@ -225,6 +237,8 @@ def main(args):
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args}
if model_ema:
checkpoint['model_ema'] = model_ema.state_dict()
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
......@@ -306,6 +320,12 @@ def get_args_parser(add_help=True):
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
parser.add_argument(
'--model-ema', action='store_true',
help='enable tracking Exponential Moving Average of model parameters')
parser.add_argument(
'--model-ema-decay', type=float, default=0.99,
help='decay factor for Exponential Moving Average of model parameters(default: 0.99)')
return parser
......
......@@ -161,6 +161,18 @@ class MetricLogger(object):
print('{} Total time: {}'.format(header, total_time_str))
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
"""Maintains moving averages of model parameters using an exponential decay.
``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
`torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
is used to compute the EMA.
"""
def __init__(self, model, decay, device='cpu'):
ema_avg = (lambda avg_model_param, model_param, num_averaged:
decay * avg_model_param + (1 - decay) * model_param)
super().__init__(model, device, ema_avg)
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment