Unverified Commit 59ec1dfd authored by Hu Ye, committed by GitHub

support amp training for detection models (#4933)



* support amp training

* support amp training

* support amp training

* Update references/detection/train.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update references/detection/engine.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* fix lint issues
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 4b20ac52
references/detection/engine.py

@@ -9,7 +9,7 @@ from coco_eval import CocoEvaluator
 from coco_utils import get_coco_api_from_dataset
 
 
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))

@@ -27,9 +27,8 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-        loss_dict = model(images, targets)
-
-        losses = sum(loss for loss in loss_dict.values())
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
 
         # reduce losses over all GPUs for logging purposes

@@ -44,6 +43,11 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
             sys.exit(1)
 
         optimizer.zero_grad()
-        losses.backward()
-        optimizer.step()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
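For context, the engine.py change follows the standard torch.cuda.amp recipe: run the forward pass under autocast and, when a GradScaler is provided, scale the loss before backward. A minimal self-contained sketch of that pattern; the toy model, data and loss below are placeholders for illustration, not part of this diff:

import torch
import torch.nn.functional as F

# Minimal sketch of the autocast + GradScaler pattern used in engine.py above.
# The toy model and random data are placeholders for illustration only.
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()  # use scaler = None to train in full precision

for _ in range(5):
    inputs = torch.randn(4, 10, device="cuda")
    targets = torch.randn(4, 1, device="cuda")

    # Forward pass in mixed precision; disabled when no scaler is given.
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        losses = F.mse_loss(model(inputs), targets)

    optimizer.zero_grad()
    if scaler is not None:
        scaler.scale(losses).backward()  # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)           # unscales gradients, skips the step if they contain inf/nan
        scaler.update()                  # adjust the scale factor for the next iteration
    else:
        losses.backward()
        optimizer.step()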
references/detection/train.py

@@ -144,6 +144,9 @@ def get_args_parser(add_help=True):
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     return parser

@@ -209,6 +212,8 @@ def main(args):
     params = [p for p in model.parameters() if p.requires_grad]
     optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "multisteplr":
         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)

@@ -225,6 +230,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, data_loader_test, device=device)

@@ -235,7 +242,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
+        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
         lr_scheduler.step()
         if args.output_dir:
             checkpoint = {

@@ -245,6 +252,8 @@ def main(args):
                 "args": args,
                 "epoch": epoch,
             }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
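The train.py changes wire that pattern into the script: the new --amp flag creates the GradScaler, the scaler is passed to train_one_epoch, and its state is written to and restored from checkpoints so the dynamic loss scale survives a resume. A rough sketch of that checkpoint round trip; the toy model, optimizer and file path are placeholders, not part of this diff:

import torch

# Sketch of how the scaler state travels through checkpoints, mirroring train.py above.
# The toy model, optimizer and file path are placeholders.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()  # only created when --amp is passed

# Saving: store the scaler state next to the model and optimizer state.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}
torch.save(checkpoint, "checkpoint.pth")

# Resuming: restore the scaler so the loss scale continues where it left off.
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])

With these changes mixed precision is opt-in: adding --amp to the usual training command enables it, and omitting the flag keeps the previous full-precision behaviour.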