Unverified Commit ba299e8f authored by kbozas's avatar kbozas Committed by GitHub
Browse files

support amp training for video classification models (#5023)



* support amp training for video classification models

* Removed extra empty line and used scaler instead of args.amp as function argument

* apply formatting to pass lint tests
Co-authored-by: Konstantinos Bozas <kbz@kbz-mbp.broadband>
parent fe4ba309
...@@ -12,19 +12,13 @@ from torch import nn ...@@ -12,19 +12,13 @@ from torch import nn
from torch.utils.data.dataloader import default_collate from torch.utils.data.dataloader import default_collate
from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
try:
from apex import amp
except ImportError:
amp = None
try: try:
from torchvision.prototype import models as PM from torchvision.prototype import models as PM
except ImportError: except ImportError:
PM = None PM = None
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False): def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
model.train() model.train()
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
...@@ -34,16 +28,19 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi ...@@ -34,16 +28,19 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
for video, target in metric_logger.log_every(data_loader, print_freq, header): for video, target in metric_logger.log_every(data_loader, print_freq, header):
start_time = time.time() start_time = time.time()
video, target = video.to(device), target.to(device) video, target = video.to(device), target.to(device)
output = model(video) with torch.cuda.amp.autocast(enabled=scaler is not None):
loss = criterion(output, target) output = model(video)
loss = criterion(output, target)
optimizer.zero_grad() optimizer.zero_grad()
if apex:
with amp.scale_loss(loss, optimizer) as scaled_loss: if scaler is not None:
scaled_loss.backward() scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else: else:
loss.backward() loss.backward()
optimizer.step() optimizer.step()
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = video.shape[0] batch_size = video.shape[0]
...@@ -101,11 +98,6 @@ def collate_fn(batch): ...@@ -101,11 +98,6 @@ def collate_fn(batch):
def main(args): def main(args):
if args.weights and PM is None: if args.weights and PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if args.apex and amp is None:
raise RuntimeError(
"Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
"to enable mixed-precision training."
)
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
...@@ -224,9 +216,7 @@ def main(args): ...@@ -224,9 +216,7 @@ def main(args):
lr = args.lr * args.world_size lr = args.lr * args.world_size
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay) optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
scaler = torch.cuda.amp.GradScaler() if args.amp else None
if args.apex:
model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
# convert scheduler to be per iteration, not per epoch, for warmup that lasts # convert scheduler to be per iteration, not per epoch, for warmup that lasts
# between different epochs # between different epochs
...@@ -267,6 +257,8 @@ def main(args): ...@@ -267,6 +257,8 @@ def main(args):
optimizer.load_state_dict(checkpoint["optimizer"]) optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
args.start_epoch = checkpoint["epoch"] + 1 args.start_epoch = checkpoint["epoch"] + 1
if args.amp:
scaler.load_state_dict(checkpoint["scaler"])
if args.test_only: if args.test_only:
evaluate(model, criterion, data_loader_test, device=device) evaluate(model, criterion, data_loader_test, device=device)
...@@ -277,9 +269,7 @@ def main(args): ...@@ -277,9 +269,7 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
if args.distributed: if args.distributed:
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
train_one_epoch( train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex
)
evaluate(model, criterion, data_loader_test, device=device) evaluate(model, criterion, data_loader_test, device=device)
if args.output_dir: if args.output_dir:
checkpoint = { checkpoint = {
...@@ -289,6 +279,8 @@ def main(args): ...@@ -289,6 +279,8 @@ def main(args):
"epoch": epoch, "epoch": epoch,
"args": args, "args": args,
} }
if args.amp:
checkpoint["scaler"] = scaler.state_dict()
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
...@@ -363,17 +355,6 @@ def parse_args(): ...@@ -363,17 +355,6 @@ def parse_args():
action="store_true", action="store_true",
) )
# Mixed precision training parameters
parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
parser.add_argument(
"--apex-opt-level",
default="O1",
type=str,
help="For apex mixed precision training"
"O0 for FP32 training, O1 for mixed precision training."
"For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
)
# distributed training parameters # distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
...@@ -381,6 +362,9 @@ def parse_args(): ...@@ -381,6 +362,9 @@ def parse_args():
# Prototype models only # Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
# Mixed precision training parameters
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
args = parser.parse_args() args = parser.parse_args()
return args return args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment