Commit aa32c937 authored by Vinh Nguyen, committed by Francisco Massa

adding mixed precision training with Apex (#972)

* adding mixed precision training with Apex

* fix APEX default optimization level

* adding python version check for apex

* fix LINT errors and raise exceptions if apex not available
parent a664dd0a
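
The change follows Apex's standard two-step amp recipe: wrap the model and optimizer once with amp.initialize before training, then route every backward pass through amp.scale_loss so FP16 gradients are scaled against underflow and unscaled before the optimizer step. A minimal, self-contained sketch of that recipe (the toy model, optimizer, and random batch are illustrative only; a CUDA device and an apex install are assumed):

import torch
from apex import amp

# Illustrative toy model and optimizer; apex amp requires a CUDA device.
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Step 1: patch model and optimizer once, before the training loop.
# 'O1' runs whitelisted ops in FP16 while keeping FP32 master weights.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

criterion = torch.nn.CrossEntropyLoss()
inputs = torch.randn(8, 10).cuda()
targets = torch.randint(0, 2, (8,)).cuda()

output = model(inputs)
loss = criterion(output, targets)

optimizer.zero_grad()
# Step 2: scale the loss so small FP16 gradients do not flush to zero;
# amp unscales them again before optimizer.step() sees them.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()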
@@ -2,6 +2,7 @@ from __future__ import print_function
 import datetime
 import os
 import time
+import sys
 
 import torch
 import torch.utils.data
@@ -11,19 +12,31 @@ from torchvision import transforms
 
 import utils
 
+try:
+    from apex import amp
+except ImportError:
+    amp = None
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
+    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))
+
     header = 'Epoch: [{}]'.format(epoch)
     for image, target in metric_logger.log_every(data_loader, print_freq, header):
         image, target = image.to(device), target.to(device)
         output = model(image)
         loss = criterion(output, target)
+        start_time = time.time()
 
         optimizer.zero_grad()
-        loss.backward()
+        if apex:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
         optimizer.step()
 
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
@@ -31,6 +44,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
         metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
         metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
 
 
 def evaluate(model, criterion, data_loader, device):
@@ -68,6 +82,13 @@ def _get_cache_path(filepath):
 
 
 def main(args):
+    if args.apex:
+        if sys.version_info < (3, 0):
+            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
+        if amp is None:
+            raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
+                               "to enable mixed-precision training.")
+
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -161,6 +182,11 @@ def main(args):
     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
 
+    if args.apex:
+        model, optimizer = amp.initialize(model, optimizer,
+                                          opt_level=args.apex_opt_level
+                                          )
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
@@ -177,7 +203,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if args.output_dir:
@@ -249,6 +275,15 @@ def parse_args():
         action="store_true",
     )
 
+    # Mixed precision training parameters
+    parser.add_argument('--apex', action='store_true',
+                        help='Use apex for mixed precision training')
+    parser.add_argument('--apex-opt-level', default='O1', type=str,
+                        help='For apex mixed precision training. '
+                             'O0 for FP32 training, O1 for mixed precision training. '
+                             'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
+                        )
+
     # distributed training parameters
     parser.add_argument('--world-size', default=1, type=int,
                         help='number of distributed processes')
...
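
With these flags in place, mixed precision stays opt-in. Assuming a standard ImageNet layout (the data path below is illustrative), enabling it should look like:

python train.py --data-path /path/to/imagenet --apex --apex-opt-level O1

Since O1 keeps FP32 master weights and only patches whitelisted ops to FP16, the script's existing checkpointing and LR-scheduler logic can stay unchanged.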