Unverified commit cca16993 authored by Hu Ye, committed by GitHub

support amp training for segmentation models (#4994)



* support amp training for segmentation models

* fix lint
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 58016b09
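
In short, the patch threads an optional GradScaler through train_one_epoch: the forward pass and loss run under torch.cuda.amp.autocast, and the backward pass and optimizer step go through the scaler when --amp is set, falling back to the plain FP32 path otherwise. A minimal standalone sketch of that pattern follows; the toy model, data, and learning rate are illustrative placeholders, not part of this commit:

```python
import torch

# Toy setup; the real script builds a segmentation model, criterion and SGD optimizer.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 4).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

use_amp = device == "cuda"  # stands in for the --amp flag added by this commit
scaler = torch.cuda.amp.GradScaler() if use_amp else None

image = torch.randn(8, 16, device=device)
target = torch.randint(0, 4, (8,), device=device)

# Same control flow as the patched train_one_epoch: autocast the forward pass,
# then either scale gradients through the GradScaler or use the plain FP32 path.
with torch.cuda.amp.autocast(enabled=scaler is not None):
    output = model(image)
    loss = criterion(output, target)

optimizer.zero_grad()
if scaler is not None:
    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.step(optimizer)          # unscales grads; skips the step if they contain inf/NaN
    scaler.update()                 # adjusts the loss scale for the next iteration
else:
    loss.backward()
    optimizer.step()
```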
@@ -72,19 +72,25 @@ def evaluate(model, data_loader, device, num_classes):
     return confmat


-def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq):
+def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     header = f"Epoch: [{epoch}]"
     for image, target in metric_logger.log_every(data_loader, print_freq, header):
         image, target = image.to(device), target.to(device)
-        output = model(image)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(image)
+            loss = criterion(output, target)

         optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            optimizer.step()

         lr_scheduler.step()
@@ -153,6 +159,8 @@ def main(args):
         params_to_optimize.append({"params": params, "lr": args.lr * 10})
     optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
     iters_per_epoch = len(data_loader)
     main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer, lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9
@@ -186,6 +194,8 @@ def main(args):
             optimizer.load_state_dict(checkpoint["optimizer"])
             lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
             args.start_epoch = checkpoint["epoch"] + 1
+            if args.amp:
+                scaler.load_state_dict(checkpoint["scaler"])

     if args.test_only:
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
@@ -196,7 +206,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
+        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
         print(confmat)
         checkpoint = {
@@ -206,6 +216,8 @@ def main(args):
             "epoch": epoch,
             "args": args,
         }
+        if args.amp:
+            checkpoint["scaler"] = scaler.state_dict()
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
         utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
@@ -269,6 +281,9 @@ def get_args_parser(add_help=True):
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     return parser
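One consequence of the checkpoint hunks worth noting: a run started with --amp writes a "scaler" entry into each checkpoint, and the resume path only restores it when --amp is passed again, so a mixed-precision run should keep the flag when resuming. Without --amp, scaler stays None, autocast is entered with enabled=False, and the training loop behaves exactly as it did before this change.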