Commit c187c2b1 authored by Vinh Nguyen, committed by Francisco Massa

Fix apex distributed training (#1124)

* adding mixed precision training with Apex

* fix APEX default optimization level

* adding python version check for apex

* fix LINT errors and raise exceptions if apex not available

* fixing apex distributed training

* fix throughput calculation: include forward pass

* remove torch.cuda.set_device(args.gpu) as it's already called in init_distributed_mode

* fix linter: new line

* move Apex initialization code back to the beginning of main

* move apex initialization to before lr_scheduler, for peace of mind (doing apex initialization after lr_scheduler seems to work fine as well); the resulting ordering is sketched below
parent 5d1372c0
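The core of the fix is the order of set-up in main: amp.initialize must receive the bare model and optimizer before the model is wrapped in DistributedDataParallel (which is what the Apex documentation asks for), and the LR scheduler is built against the returned optimizer afterwards; per the commit message, the scheduler placement is mainly for peace of mind. Below is a minimal sketch of that ordering, assuming NVIDIA Apex is installed. The helper name build_training_objects is illustrative; the args fields mirror the diff below.

import torch
from apex import amp  # NVIDIA Apex; assumed to be installed


def build_training_objects(model, args):
    # Illustrative sketch of the ordering this commit establishes in main().
    model.to(args.device)  # move to device before any wrapping (done earlier in the real main())
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr,
        momentum=args.momentum, weight_decay=args.weight_decay)

    # 1) Patch the bare model/optimizer for mixed precision first ...
    if args.apex:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.apex_opt_level)

    # 2) ... then build the scheduler on the optimizer amp.initialize returned ...
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    # 3) ... and only then wrap in DDP, keeping a handle to the bare module
    #    for checkpointing.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module

    return model, model_without_ddp, criterion, optimizer, lr_scheduler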
@@ -26,11 +26,11 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
 
     header = 'Epoch: [{}]'.format(epoch)
     for image, target in metric_logger.log_every(data_loader, print_freq, header):
+        start_time = time.time()
         image, target = image.to(device), target.to(device)
         output = model(image)
         loss = criterion(output, target)
 
-        start_time = time.time()
         optimizer.zero_grad()
         if apex:
             with amp.scale_loss(loss, optimizer) as scaled_loss:
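This hunk moves start_time to the top of the loop body so that the reported img/s covers the host-to-device copy and the forward pass as well as the backward pass and optimizer step. A self-contained sketch of the measurement under that ordering follows; the tiny nn.Linear model, the synthetic data, and the images_per_sec variable are stand-ins for illustration only, and the real script routes the backward pass through amp.scale_loss when Apex is enabled and logs the figure via its metric logger.

import time

import torch
from torch import nn

# Tiny stand-in model and data so the timing pattern runs on its own.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(32, 10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = [(torch.randn(8, 32), torch.randint(0, 10, (8,))) for _ in range(5)]

for image, target in data_loader:
    start_time = time.time()  # timer now starts before the forward pass (the fix)
    image, target = image.to(device), target.to(device)
    output = model(image)
    loss = criterion(output, target)

    optimizer.zero_grad()
    loss.backward()  # the real loop wraps this in amp.scale_loss when apex is on
    optimizer.step()

    batch_size = image.shape[0]
    images_per_sec = batch_size / (time.time() - start_time)  # includes the forward pass
    print('img/s: {:.1f}'.format(images_per_sec))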
@@ -170,23 +170,23 @@ def main(args):
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
-    model_without_ddp = model
-    if args.distributed:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
-        model_without_ddp = model.module
-
     criterion = nn.CrossEntropyLoss()
 
     optimizer = torch.optim.SGD(
         model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
-
     if args.apex:
         model, optimizer = amp.initialize(model, optimizer,
                                           opt_level=args.apex_opt_level
                                           )
 
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
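The second hunk also keeps model_without_ddp pointing at the unwrapped module, so checkpoints store plain parameter names rather than the 'module.'-prefixed keys that DDP exposes, and resuming works the same whether or not the run is distributed. A small sketch of that convention follows; the stand-in model/optimizer/scheduler and any checkpoint keys beyond 'model' are illustrative, not taken from the script.

import torch
from torch import nn

# Stand-ins for the objects built in main(); in the real script these come
# from the code shown in the hunk above.
model_without_ddp = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model_without_ddp.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Save the bare module's weights so the keys carry no DDP 'module.' prefix.
checkpoint = {
    'model': model_without_ddp.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')

# Resume by loading back into the bare module, as the diff above does.
checkpoint = torch.load('checkpoint.pth', map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])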