"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f48f9c250fe7f6027c7445f8be04bb4499583f58"
Unverified Commit 12d2c737 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fixes for PyTorch 1.1 (#919)

parent 86db394e
......@@ -147,7 +147,7 @@ def main(args):
model = torchvision.models.__dict__[args.model]()
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.utils.convert_sync_batchnorm(model)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model_without_ddp = model
if args.distributed:
......@@ -177,8 +177,8 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
lr_scheduler.step()
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if args.output_dir:
checkpoint = {
......
......@@ -124,7 +124,7 @@ def main(args):
model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, aux_loss=args.aux_loss)
model.to(device)
if args.distributed:
model = torch.nn.utils.convert_sync_batchnorm(model)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment