Unverified Commit 403bded3 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Updated classification reference script to use torch.cuda.amp (#4547)

* Updated classification reference script to use torch.cuda.amp

* Assigned scaler to None if amp is False

* Fixed linter errors
parent 261cbf7e
...@@ -12,13 +12,10 @@ from torch import nn ...@@ -12,13 +12,10 @@ from torch import nn
from torch.utils.data.dataloader import default_collate from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
try:
from apex import amp
except ImportError:
amp = None
def train_one_epoch(
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False, model_ema=None): model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
):
model.train() model.train()
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
...@@ -29,13 +26,16 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri ...@@ -29,13 +26,16 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
start_time = time.time() start_time = time.time()
image, target = image.to(device), target.to(device) image, target = image.to(device), target.to(device)
output = model(image) output = model(image)
loss = criterion(output, target)
optimizer.zero_grad() optimizer.zero_grad()
if apex: if amp:
with amp.scale_loss(loss, optimizer) as scaled_loss: with torch.cuda.amp.autocast():
scaled_loss.backward() loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else: else:
loss = criterion(output, target)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
...@@ -156,12 +156,6 @@ def load_data(traindir, valdir, args): ...@@ -156,12 +156,6 @@ def load_data(traindir, valdir, args):
def main(args): def main(args):
if args.apex and amp is None:
raise RuntimeError(
"Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
"to enable mixed-precision training."
)
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -228,8 +222,7 @@ def main(args): ...@@ -228,8 +222,7 @@ def main(args):
else: else:
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt)) raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
if args.apex: scaler = torch.cuda.amp.GradScaler() if args.amp else None
model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
args.lr_scheduler = args.lr_scheduler.lower() args.lr_scheduler = args.lr_scheduler.lower()
if args.lr_scheduler == "steplr": if args.lr_scheduler == "steplr":
...@@ -292,7 +285,9 @@ def main(args): ...@@ -292,7 +285,9 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
if args.distributed: if args.distributed:
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema) train_one_epoch(
model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
)
lr_scheduler.step() lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device) evaluate(model, criterion, data_loader_test, device=device)
if model_ema: if model_ema:
...@@ -385,15 +380,7 @@ def get_args_parser(add_help=True): ...@@ -385,15 +380,7 @@ def get_args_parser(add_help=True):
parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
# Mixed precision training parameters # Mixed precision training parameters
parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training") parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
parser.add_argument(
"--apex-opt-level",
default="O1",
type=str,
help="For apex mixed precision training"
"O0 for FP32 training, O1 for mixed precision training."
"For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
)
# distributed training parameters # distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment