"src/diffusers/pipelines/cogvideo/__init__.py" did not exist on "84cd9e8d01adb47f046b1ee449fc76a0c32dc4e2"
Unverified Commit ba299e8f authored by kbozas, committed by GitHub

support amp training for video classification models (#5023)



* support amp training for video classification models

* Removed extra empty line and used scaler instead of args.amp as function argument

* apply formatting to pass lint tests

Co-authored-by: Konstantinos Bozas <kbz@kbz-mbp.broadband>
parent fe4ba309
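
For context, this commit swaps NVIDIA Apex for PyTorch's native automatic mixed precision (`torch.cuda.amp`). Below is a minimal, self-contained sketch of the pattern the patched `train_one_epoch` follows; the toy model, data, and hyperparameters are illustrative stand-ins, not code from this commit:

```python
import torch
from torch import nn

# Hypothetical toy setup (requires a CUDA device); the real script
# builds a video classification model and a clip-sampling DataLoader.
model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
use_amp = True

# The scaler is only created when AMP is requested; passing scaler=None
# falls back to plain FP32 training, exactly as in the patched function.
scaler = torch.cuda.amp.GradScaler() if use_amp else None

for _ in range(10):
    inputs = torch.randn(8, 16, device="cuda")
    targets = torch.randint(0, 4, (8,), device="cuda")

    # autocast runs the forward pass in mixed precision when enabled.
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    optimizer.zero_grad()
    if scaler is not None:
        # Scale the loss to avoid FP16 gradient underflow, then step through
        # the scaler so it can unscale gradients and skip inf/NaN steps.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
```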
@@ -12,19 +12,13 @@ from torch import nn
 from torch.utils.data.dataloader import default_collate
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
 try:
     from torchvision.prototype import models as PM
 except ImportError:
     PM = None
 
 
-def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -34,13 +28,16 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
     for video, target in metric_logger.log_every(data_loader, print_freq, header):
         start_time = time.time()
         video, target = video.to(device), target.to(device)
-        output = model(video)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(video)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
             loss.backward()
             optimizer.step()
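
One refinement this hunk does not include, but that the same `GradScaler` API supports, is gradient clipping under AMP: gradients must be unscaled before the clipping function sees their true magnitudes. A sketch assuming clipping is wanted (`max_norm=1.0` is an illustrative value, not from this commit):

```python
# Hypothetical extension, not part of this commit.
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # restore true gradient values before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)  # step() detects the grads are already unscaled
scaler.update()
```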
@@ -101,11 +98,6 @@ def collate_fn(batch):
 def main(args):
     if args.weights and PM is None:
         raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
 
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -224,9 +216,7 @@ def main(args):
     lr = args.lr * args.world_size
     optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     # convert scheduler to be per iteration, not per epoch, for warmup that lasts
     # between different epochs
@@ -267,6 +257,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -277,9 +269,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(
-            model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex
-        )
+        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
         evaluate(model, criterion, data_loader_test, device=device)
         if args.output_dir:
             checkpoint = {
@@ -289,6 +279,8 @@ def main(args):
                 "epoch": epoch,
                 "args": args,
             }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
@@ -363,17 +355,6 @@ def parse_args():
         action="store_true",
     )
 
-    # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
-        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
-    )
-
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
@@ -381,6 +362,9 @@ def parse_args():
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     args = parser.parse_args()
     return args
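
Taken together, the checkpoint hunks above mean the `GradScaler` state (including its dynamic loss scale) survives a training resume. A minimal round-trip sketch, with an illustrative file path rather than the script's real output layout:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

# Save: mirror what the patched script stores under the "scaler" key.
checkpoint = {"scaler": scaler.state_dict()}
torch.save(checkpoint, "checkpoint.pth")

# Resume: restore the dynamic loss scale so training picks up where it
# left off instead of re-warming the scale factor from its default.
checkpoint = torch.load("checkpoint.pth")
scaler.load_state_dict(checkpoint["scaler"])
```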