Unverified commit 69ba6789, authored by Hang Zhang, committed by GitHub

Add DeepLabV3 + ResNeSt-269 (#263)

parent 17be9e16
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from __future__ import print_function
import os
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
import encoding
from encoding.nn import LabelSmoothing, NLLMultiLabelSmooth
from encoding.utils import (accuracy, AverageMeter, MixUpWrapper, LR_Scheduler)
class Options():
def __init__(self):
# data settings
parser = argparse.ArgumentParser(description='Deep Encoding')
parser.add_argument('--dataset', type=str, default='cifar10',
help='training dataset (default: cifar10)')
parser.add_argument('--base-size', type=int, default=None,
help='base image size')
parser.add_argument('--crop-size', type=int, default=224,
help='crop image size')
parser.add_argument('--label-smoothing', type=float, default=0.0,
help='label-smoothing (default eta: 0.0)')
parser.add_argument('--mixup', type=float, default=0.0,
help='mixup (default eta: 0.0)')
        parser.add_argument('--rand-aug', action='store_true',
                            default=False, help='apply RandAugment data augmentation')
# model params
parser.add_argument('--model', type=str, default='densenet',
help='network model type (default: densenet)')
        parser.add_argument('--pretrained', action='store_true',
                            default=False, help='load pretrained model')
parser.add_argument('--rectify', action='store_true',
default=False, help='rectify convolution')
        parser.add_argument('--rectify-avg', action='store_true',
                            default=False, help='use average mode for rectified convolution')
parser.add_argument('--last-gamma', action='store_true', default=False,
help='whether to init gamma of the last BN layer in \
each bottleneck to 0 (default: False)')
parser.add_argument('--dropblock-prob', type=float, default=0,
help='DropBlock prob. default is 0.')
parser.add_argument('--final-drop', type=float, default=0,
help='final dropout prob. default is 0.')
# training hyper params
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
help='batch size for training (default: 128)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
help='batch size for testing (default: 256)')
        parser.add_argument('--epochs', type=int, default=120, metavar='N',
                            help='number of epochs to train (default: 120)')
        parser.add_argument('--start_epoch', type=int, default=0,
                            metavar='N', help='the epoch number to start (default: 0)')
parser.add_argument('--workers', type=int, default=32,
metavar='N', help='dataloader threads')
# optimizer
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
help='learning rate (default: 0.1)')
parser.add_argument('--lr-scheduler', type=str, default='cos',
help='learning rate scheduler (default: cos)')
parser.add_argument('--warmup-epochs', type=int, default=0,
help='number of warmup epochs (default: 0)')
parser.add_argument('--lr-step', type=int, default=40, metavar='LR',
help='learning rate step (default: 40)')
parser.add_argument('--momentum', type=float, default=0.9,
metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=1e-4,
metavar ='M', help='SGD weight decay (default: 1e-4)')
        parser.add_argument('--no-bn-wd', action='store_true',
                            default=False, help='disable weight decay on BN parameters and biases')
# cuda, seed and logging
parser.add_argument('--no-cuda', action='store_true',
default=False, help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
        # checkpoint
parser.add_argument('--resume', type=str, default=None,
help='put the path to resuming file if needed')
parser.add_argument('--checkname', type=str, default='default',
help='set the checkpoint name')
# evaluation option
        parser.add_argument('--eval', action='store_true', default=False,
                            help='evaluate only')
        parser.add_argument('--export', type=str, default=None,
                            help='path prefix for exporting the trained weights')
self.parser = parser
def parse(self):
args = self.parser.parse_args()
return args
# global variable
best_pred = 0.0
acclist_train = []
acclist_val = []
def main():
# init the args
global best_pred, acclist_train, acclist_val
args = Options().parse()
args.cuda = not args.no_cuda and torch.cuda.is_available()
print(args)
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
# init dataloader
transform_train, transform_val = encoding.transforms.get_transform(
args.dataset, args.base_size, args.crop_size, args.rand_aug)
trainset = encoding.datasets.get_dataset(args.dataset, root=os.path.expanduser('~/.encoding/data'),
transform=transform_train, train=True, download=True)
valset = encoding.datasets.get_dataset(args.dataset, root=os.path.expanduser('~/.encoding/data'),
transform=transform_val, train=False, download=True)
train_loader = torch.utils.data.DataLoader(
trainset, batch_size=args.batch_size, shuffle=True,
num_workers=args.workers, drop_last=True, pin_memory=True)
val_loader = torch.utils.data.DataLoader(
valset, batch_size=args.test_batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)
# init the model
model_kwargs = {}
if args.pretrained:
model_kwargs['pretrained'] = True
if args.final_drop > 0.0:
model_kwargs['final_drop'] = args.final_drop
if args.dropblock_prob > 0.0:
model_kwargs['dropblock_prob'] = args.dropblock_prob
if args.last_gamma:
model_kwargs['last_gamma'] = True
if args.rectify:
model_kwargs['rectified_conv'] = True
model_kwargs['rectify_avg'] = args.rectify_avg
model = encoding.models.get_model(args.model, **model_kwargs)
if args.dropblock_prob > 0.0:
from functools import partial
from encoding.nn import reset_dropblock
nr_iters = (args.epochs - 2 * args.warmup_epochs) * len(train_loader)
apply_drop_prob = partial(reset_dropblock, args.warmup_epochs*len(train_loader),
nr_iters, 0.0, args.dropblock_prob)
model.apply(apply_drop_prob)
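        # assumption about reset_dropblock's behavior: it ramps each DropBlock
        # layer's drop rate linearly from 0.0 up to args.dropblock_prob over the
        # iterations computed above, instead of applying the full rate immediately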
print(model)
# criterion and optimizer
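    # mixup produces soft (non one-hot) targets, so it is paired with the
    # NLL-based smoothing loss below; note MixUpWrapper hardcodes 1000 classes
    # (an ImageNet assumption)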
if args.mixup > 0:
train_loader = MixUpWrapper(args.mixup, 1000, train_loader,
list(range(torch.cuda.device_count())))
criterion = NLLMultiLabelSmooth(args.label_smoothing)
elif args.label_smoothing > 0.0:
criterion = LabelSmoothing(args.label_smoothing)
else:
criterion = nn.CrossEntropyLoss()
if args.no_bn_wd:
parameters = model.named_parameters()
param_dict = {}
for k, v in parameters:
param_dict[k] = v
bn_params = [v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)]
rest_params = [v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)]
print(" Weight decay NOT applied to BN parameters ")
print(f'len(parameters): {len(list(model.parameters()))} = {len(bn_params)} + {len(rest_params)}')
optimizer = torch.optim.SGD([{'params': bn_params, 'weight_decay': 0 },
{'params': rest_params, 'weight_decay': args.weight_decay}],
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
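        # per-group 'weight_decay' entries override the constructor-level default,
        # so the BN/bias group above really trains with zero decay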
else:
optimizer = torch.optim.SGD(model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
if args.cuda:
model.cuda()
criterion.cuda()
# Please use CUDA_VISIBLE_DEVICES to control the number of gpus
model = nn.DataParallel(model)
if args.resume is not None:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1 if args.start_epoch == 0 else args.start_epoch
best_pred = checkpoint['best_pred']
acclist_train = checkpoint['acclist_train']
acclist_val = checkpoint['acclist_val']
model.module.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
raise RuntimeError ("=> no resume checkpoint found at '{}'".\
format(args.resume))
scheduler = LR_Scheduler(args.lr_scheduler,
base_lr=args.lr,
num_epochs=args.epochs,
iters_per_epoch=len(train_loader),
warmup_epochs=args.warmup_epochs,
lr_step=args.lr_step)
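    # assumption about LR_Scheduler's 'cos' mode: after a linear warmup over
    # warmup_epochs, lr follows base_lr * 0.5 * (1 + cos(pi * t / T)), with t the
    # current iteration and T the total number of post-warmup iterations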
def train(epoch):
model.train()
losses = AverageMeter()
top1 = AverageMeter()
global best_pred, acclist_train
tbar = tqdm(train_loader, desc='\r')
for batch_idx, (data, target) in enumerate(tbar):
scheduler(optimizer, batch_idx, epoch, best_pred)
#criterion.update(batch_idx, epoch)
if args.cuda:
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
acc1 = accuracy(output, target, topk=(1,))
top1.update(acc1[0], data.size(0))
losses.update(loss.item(), data.size(0))
tbar.set_description('\rLoss: %.3f | Top1: %.3f'%(losses.avg, top1.avg))
acclist_train += [top1.avg]
def validate(epoch):
model.eval()
top1 = AverageMeter()
top5 = AverageMeter()
global best_pred, acclist_train, acclist_val
is_best = False
tbar = tqdm(val_loader, desc='\r')
for batch_idx, (data, target) in enumerate(tbar):
if args.cuda:
data, target = data.cuda(), target.cuda()
with torch.no_grad():
output = model(data)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
top1.update(acc1[0], data.size(0))
top5.update(acc5[0], data.size(0))
tbar.set_description('Top1: %.3f | Top5: %.3f'%(top1.avg, top5.avg))
if args.eval:
print('Top1 Acc: %.3f | Top5 Acc: %.3f '%(top1.avg, top5.avg))
return
# save checkpoint
acclist_val += [top1.avg]
if top1.avg > best_pred:
best_pred = top1.avg
is_best = True
encoding.utils.save_checkpoint({
'args': args,
'epoch': epoch,
'state_dict': model.module.state_dict(),
'optimizer': optimizer.state_dict(),
'best_pred': best_pred,
'acclist_train':acclist_train,
'acclist_val':acclist_val,
}, args=args, is_best=is_best)
if args.export:
torch.save(model.module.state_dict(), args.export + '.pth')
return
if args.eval:
validate(args.start_epoch)
return
for epoch in range(args.start_epoch, args.epochs):
train(epoch)
validate(epoch)
validate(epoch)
if __name__ == "__main__":
main()
# baseline
python train_dist.py --dataset imagenet --model resnet50 --lr-scheduler cos --epochs 120 --checkname resnet50_check --lr 0.025 --batch-size 64
# rectify
python train_dist.py --dataset imagenet --model resnet50 --lr-scheduler cos --epochs 120 --checkname resnet50_rt --lr 0.1 --batch-size 256 --rectify
# warmup
python train_dist.py --dataset imagenet --model resnet50 --lr-scheduler cos --epochs 120 --checkname resnet50_rt_warm --lr 0.1 --batch-size 256 --warmup-epochs 5 --rectify
# no-bn-wd
python train_dist.py --dataset imagenet --model resnet50 --lr-scheduler cos --epochs 120 --checkname resnet50_rt_nobnwd_warm --lr 0.1 --batch-size 256 --no-bn-wd --warmup-epochs 5 --rectify
# LS
python train_dist.py --dataset imagenet --model resnet50 --lr-scheduler cos --epochs 120 --checkname resnet50_rt_ls --lr 0.1 --batch-size 256 --label-smoothing 0.1 --rectify
# Mixup + LS
python train_dist.py --dataset imagenet --model resnet50 --lr-scheduler cos --epochs 200 --checkname resnet50_rt_ls_mixup --lr 0.1 --batch-size 256 --label-smoothing 0.1 --mixup 0.2 --rectify
# last-gamma
python train_dist.py --dataset imagenet --model resnet50 --lr-scheduler cos --epochs 120 --checkname resnet50_rt_gamma --lr 0.1 --batch-size 256 --last-gamma --rectify
# BoTs
python train_dist.py --dataset imagenet --model resnet50 --lr-scheduler cos --epochs 200 --checkname resnet50_rt_bots --lr 0.1 --batch-size 256 --label-smoothing 0.1 --mixup 0.2 --last-gamma --no-bn-wd --warmup-epochs 5 --rectify
# resnet50d
python train_dist.py --dataset imagenet --model resnet50d --lr-scheduler cos --epochs 200 --checkname resnet50d_rt_bots --lr 0.1 --batch-size 256 --label-smoothing 0.1 --mixup 0.2 --last-gamma --no-bn-wd --warmup-epochs 5 --rectify
# dropblock
python train_dist.py --dataset imagenet --model resnet50 --lr-scheduler cos --epochs 200 --checkname resnet50_rt_dropblock --lr 0.1 --batch-size 256 --label-smoothing 0.1 --mixup 0.2 --dropblock-prob 0.1 --rectify
# resnest50
python train_dist.py --dataset imagenet --model resnest50 --lr-scheduler cos --epochs 270 --checkname resnest50_rt_bots --lr 0.1 --batch-size 256 --label-smoothing 0.1 --mixup 0.2 --last-gamma --no-bn-wd --warmup-epochs 5 --dropblock-prob 0.1 --rectify
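
# gpu selection: train_dist.py consumes every visible device, so restrict GPUs with
# CUDA_VISIBLE_DEVICES; assuming the linear scaling rule used above (lr = 0.1 * batch/256),
# a 4-GPU run of the same recipe might look like:
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_dist.py --dataset imagenet --model resnest50 --lr-scheduler cos --epochs 270 --checkname resnest50_4gpu --lr 0.05 --batch-size 128 --label-smoothing 0.1 --mixup 0.2 --last-gamma --no-bn-wd --warmup-epochs 5 --dropblock-prob 0.1 --rectify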
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import argparse
import torch
from thop import profile, clever_format
import encoding
def get_args():
# data settings
parser = argparse.ArgumentParser(description='Deep Encoding')
parser.add_argument('--crop-size', type=int, default=224,
help='crop image size')
# model params
parser.add_argument('--model', type=str, default='densenet',
help='network model type (default: densenet)')
parser.add_argument('--rectify', action='store_true',
default=False, help='rectify convolution')
    parser.add_argument('--rectify-avg', action='store_true',
                        default=False, help='use average mode for rectified convolution')
args = parser.parse_args()
return args
def main():
args = get_args()
model_kwargs = {}
if args.rectify:
model_kwargs['rectified_conv'] = True
model_kwargs['rectify_avg'] = args.rectify_avg
model = encoding.models.get_model(args.model, **model_kwargs)
print(model)
dummy_images = torch.rand(1, 3, args.crop_size, args.crop_size)
#count_ops(model, dummy_images, verbose=False)
macs, params = profile(model, inputs=(dummy_images, ))
macs, params = clever_format([macs, params], "%.3f")
print(f"macs: {macs}, params: {params}")
if __name__ == '__main__':
main()
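
# Example invocation (assuming the script above is saved as flops.py; the output
# format comes from thop's clever_format, the numbers below are illustrative only):
#   python flops.py --model resnest50 --crop-size 224
#   => macs: 5.4G, params: 27.5M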
@@ -20,14 +20,14 @@ from torch.nn.parallel import DistributedDataParallel
 import encoding
 from encoding.nn import LabelSmoothing, NLLMultiLabelSmooth
-from encoding.utils import (accuracy, AverageMeter, MixUpWrapper, LR_Scheduler)
+from encoding.utils import (accuracy, AverageMeter, MixUpWrapper, LR_Scheduler, torch_dist_sum)
 
 class Options():
     def __init__(self):
         # data settings
         parser = argparse.ArgumentParser(description='Deep Encoding')
-        parser.add_argument('--dataset', type=str, default='cifar10',
-                            help='training dataset (default: cifar10)')
+        parser.add_argument('--dataset', type=str, default='imagenet',
+                            help='training dataset (default: imagenet)')
         parser.add_argument('--base-size', type=int, default=None,
                             help='base image size')
         parser.add_argument('--crop-size', type=int, default=224,
@@ -95,6 +95,11 @@ class Options():
                             help='url used to set up distributed training')
         parser.add_argument('--dist-backend', default='nccl', type=str,
                             help='distributed backend')
+        # evaluation option
+        parser.add_argument('--eval', action='store_true', default= False,
+                            help='evaluating')
+        parser.add_argument('--export', type=str, default=None,
+                            help='put the path to resuming file if needed')
         self.parser = parser
 
     def parse(self):
@@ -113,21 +118,6 @@ best_pred = 0.0
 acclist_train = []
 acclist_val = []
 
-def torch_dist_sum(gpu, *args):
-    process_group = torch.distributed.group.WORLD
-    tensor_args = []
-    pending_res = []
-    for arg in args:
-        if isinstance(arg, torch.Tensor):
-            tensor_arg = arg.clone().reshape(1).detach().cuda(gpu)
-        else:
-            tensor_arg = torch.tensor(arg).reshape(1).cuda(gpu)
-        tensor_args.append(tensor_arg)
-        pending_res.append(torch.distributed.all_reduce(tensor_arg, group=process_group, async_op=True))
-    for res in pending_res:
-        res.wait()
-    return tensor_args
-
 def main_worker(gpu, ngpus_per_node, args):
     args.gpu = gpu
     args.rank = args.rank * ngpus_per_node + gpu
@@ -296,6 +286,14 @@ def main_worker(gpu, ngpus_per_node, args):
         # sum all
         sum1, cnt1, sum5, cnt5 = torch_dist_sum(args.gpu, top1.sum, top1.count, top5.sum, top5.count)
 
+        if args.eval:
+            if args.gpu == 0:
+                top1_acc = sum(sum1) / sum(cnt1)
+                top5_acc = sum(sum5) / sum(cnt5)
+                print('Validation: Top1: %.3f | Top5: %.3f'%(top1_acc, top5_acc))
+            return
+
         if args.gpu == 0:
             top1_acc = sum(sum1) / sum(cnt1)
             top5_acc = sum(sum5) / sum(cnt5)
@@ -315,16 +313,33 @@ def main_worker(gpu, ngpus_per_node, args):
                 'acclist_val':acclist_val,
             }, args=args, is_best=is_best)
 
+    if args.export:
+        if args.gpu == 0:
+            torch.save(model.module.state_dict(), args.export + '.pth')
+        return
+
+    if args.eval:
+        validate(args.start_epoch)
+        return
+
     for epoch in range(args.start_epoch, args.epochs):
         tic = time.time()
         train(epoch)
-        if epoch % 10 == 0:
+        if epoch % 10 == 0:# or epoch == args.epochs-1:
             validate(epoch)
         elapsed = time.time() - tic
         if args.gpu == 0:
             print(f'Epoch: {epoch}, Time cost: {elapsed}')
-    validate(epoch)
+
+    if args.gpu == 0:
+        encoding.utils.save_checkpoint({
+            'epoch': args.epochs-1,
+            'state_dict': model.module.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'best_pred': best_pred,
+            'acclist_train':acclist_train,
+            'acclist_val':acclist_val,
+        }, args=args, is_best=False)
 
 if __name__ == "__main__":
     main()
@@ -21,8 +21,8 @@ class Options():
     def __init__(self):
         # data settings
         parser = argparse.ArgumentParser(description='Deep Encoding')
-        parser.add_argument('--dataset', type=str, default='cifar10',
-                            help='training dataset (default: cifar10)')
+        parser.add_argument('--dataset', type=str, default='imagenet',
+                            help='training dataset (default: imagenet)')
         parser.add_argument('--base-size', type=int, default=None,
                             help='base image size')
         parser.add_argument('--crop-size', type=int, default=224,
@@ -87,13 +87,6 @@ class Options():
     print(args)
     return args
 
-@torch.no_grad()
-def reset_bn_statistics(m):
-    if isinstance(m, torch.nn.BatchNorm2d):
-        #print(m)
-        m.momentum = 0.0
-        m.reset_running_stats()
-
 def test(args):
     # output folder
     outdir = 'outdir'
@@ -120,16 +113,18 @@ def test(args):
                                  drop_last=False, shuffle=False,
                                  collate_fn=test_batchify_fn, **loader_kwargs)
     # model
+    pretrained = args.resume is None and args.verify is None
     if args.model_zoo is not None:
-        model = get_model(args.model_zoo, pretrained=True)
-        #model.base_size = args.base_size
-        #model.crop_size = args.crop_size
+        model = get_model(args.model_zoo, pretrained=pretrained)
+        model.base_size = args.base_size
+        model.crop_size = args.crop_size
     else:
         model = get_segmentation_model(args.model, dataset=args.dataset,
                                        backbone=args.backbone, aux=args.aux,
                                        se_loss=args.se_loss,
                                        norm_layer=torch.nn.BatchNorm2d if args.acc_bn else SyncBatchNorm,
                                        base_size=args.base_size, crop_size=args.crop_size)
     # resuming checkpoint
     if args.verify is not None and os.path.isfile(args.verify):
         print("=> loading checkpoint '{}'".format(args.verify))
@@ -139,27 +134,21 @@ def test(args):
         # strict=False, so that it is compatible with old pytorch saved models
         model.load_state_dict(checkpoint['state_dict'])
         print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
-    else:
+    elif not pretrained:
         raise RuntimeError ("=> no checkpoint found")
 
     print(model)
-    # accumulate bn statistics
     if args.acc_bn:
-        print('Reseting BN statistics')
-        model.apply(reset_bn_statistics)
+        from encoding.utils.precise_bn import update_bn_stats
         data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                        'crop_size': args.crop_size}
         trainset = get_dataset(args.dataset, split=args.train_split, mode='train', **data_kwargs)
-        trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
+        trainloader = data.DataLoader(ReturnFirstClosure(trainset), batch_size=args.batch_size,
                                       drop_last=True, shuffle=True, **loader_kwargs)
-        tbar = tqdm(trainloader)
-        model.train()
+        print('Reseting BN statistics')
+        #model.apply(reset_bn_statistics)
         model.cuda()
-        for i, (image, dst) in enumerate(tbar):
-            image = image.cuda()
-            with torch.no_grad():
-                outputs = model(image)
-            if i > 1000: break
+        update_bn_stats(model, trainloader)
 
     if args.export:
         torch.save(model.state_dict(), args.export + '.pth')
@@ -191,6 +180,17 @@ def test(args):
     print('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
 
+class ReturnFirstClosure(object):
+    def __init__(self, data):
+        self._data = data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, idx):
+        outputs = self._data[idx]
+        return outputs[0]
+
 if __name__ == "__main__":
     args = Options().parse()
     torch.manual_seed(args.seed)
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2017
###########################################################################
import os
import copy
import argparse
import numpy as np
from tqdm import tqdm
import torch
from torch.utils import data
import torchvision.transforms as transform
from torch.nn.parallel.scatter_gather import gather
import encoding.utils as utils
from encoding.nn import SegmentationLosses, SyncBatchNorm
from encoding.parallel import DataParallelModel, DataParallelCriterion
from encoding.datasets import get_dataset
from encoding.models import get_segmentation_model
class Options():
def __init__(self):
parser = argparse.ArgumentParser(description='PyTorch \
Segmentation')
# model and dataset
parser.add_argument('--model', type=str, default='encnet',
help='model name (default: encnet)')
parser.add_argument('--backbone', type=str, default='resnet50',
help='backbone name (default: resnet50)')
        parser.add_argument('--dataset', type=str, default='ade20k',
                            help='dataset name (default: ade20k)')
parser.add_argument('--workers', type=int, default=16,
metavar='N', help='dataloader threads')
parser.add_argument('--base-size', type=int, default=520,
help='base image size')
parser.add_argument('--crop-size', type=int, default=480,
help='crop image size')
parser.add_argument('--train-split', type=str, default='train',
help='dataset train split (default: train)')
# training hyper params
        parser.add_argument('--aux', action='store_true', default=False,
                            help='use auxiliary loss')
        parser.add_argument('--aux-weight', type=float, default=0.2,
                            help='auxiliary loss weight (default: 0.2)')
        parser.add_argument('--se-loss', action='store_true', default=False,
                            help='Semantic Encoding Loss (SE-loss)')
        parser.add_argument('--se-weight', type=float, default=0.2,
                            help='SE-loss weight (default: 0.2)')
parser.add_argument('--epochs', type=int, default=None, metavar='N',
help='number of epochs to train (default: auto)')
        parser.add_argument('--start_epoch', type=int, default=0,
                            metavar='N', help='start epoch (default: 0)')
        parser.add_argument('--batch-size', type=int, default=16, metavar='N',
                            help='input batch size for training (default: 16)')
        parser.add_argument('--test-batch-size', type=int, default=16, metavar='N',
                            help='input batch size for testing (default: 16)')
# optimizer params
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (default: auto)')
parser.add_argument('--lr-scheduler', type=str, default='poly',
help='learning rate scheduler (default: poly)')
parser.add_argument('--momentum', type=float, default=0.9,
metavar='M', help='momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=1e-4,
metavar='M', help='w-decay (default: 1e-4)')
# cuda, seed and logging
        parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
        # checkpoint
parser.add_argument('--resume', type=str, default=None,
help='put the path to resuming file if needed')
parser.add_argument('--checkname', type=str, default='default',
help='set the checkpoint name')
parser.add_argument('--model-zoo', type=str, default=None,
help='evaluating on model zoo model')
# finetuning pre-trained models
parser.add_argument('--ft', action='store_true', default= False,
help='finetuning on a different dataset')
# evaluation option
parser.add_argument('--eval', action='store_true', default= False,
help='evaluating mIoU')
parser.add_argument('--test-val', action='store_true', default= False,
help='generate masks on val set')
parser.add_argument('--no-val', action='store_true', default= False,
help='skip validation during training')
# test option
parser.add_argument('--test-folder', type=str, default=None,
help='path to test image folder')
# the parser
self.parser = parser
def parse(self):
args = self.parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
# default settings for epochs, batch_size and lr
        if args.epochs is None:
            default_epochs = {
                'coco': 30,
                'pascal_aug': 80,
                'pascal_voc': 50,
                'pcontext': 80,
                'ade20k': 180,
                'citys': 240,
            }
            args.epochs = default_epochs[args.dataset.lower()]
if args.lr is None:
lrs = {
'coco': 0.004,
'pascal_aug': 0.001,
'pascal_voc': 0.0001,
'pcontext': 0.001,
'ade20k': 0.004,
'citys': 0.004,
}
args.lr = lrs[args.dataset.lower()] / 16 * args.batch_size
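            # linear scaling rule: the per-dataset LRs above are tuned for batch
            # size 16, so the effective LR scales with the actual batch size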
print(args)
return args
class Trainer():
def __init__(self, args):
self.args = args
# data transforms
input_transform = transform.Compose([
transform.ToTensor(),
transform.Normalize([.485, .456, .406], [.229, .224, .225])])
# dataset
data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
'crop_size': args.crop_size}
trainset = get_dataset(args.dataset, split=args.train_split, mode='train', **data_kwargs)
testset = get_dataset(args.dataset, split='val', mode ='val', **data_kwargs)
# dataloader
kwargs = {'num_workers': args.workers, 'pin_memory': True} \
if args.cuda else {}
self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
drop_last=True, shuffle=True, **kwargs)
self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
drop_last=False, shuffle=False, **kwargs)
self.nclass = trainset.num_class
# model
model = get_segmentation_model(args.model, dataset=args.dataset,
backbone = args.backbone, aux = args.aux,
se_loss = args.se_loss, norm_layer = SyncBatchNorm,
base_size=args.base_size, crop_size=args.crop_size)
print(model)
# optimizer using different LR
params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr},]
if hasattr(model, 'head'):
params_list.append({'params': model.head.parameters(), 'lr': args.lr*10})
if hasattr(model, 'auxlayer'):
params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr*10})
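        # the randomly-initialized head/auxlayer train at 10x the backbone LR,
        # since the backbone starts from ImageNet-pretrained weights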
optimizer = torch.optim.SGD(params_list, lr=args.lr,
momentum=args.momentum, weight_decay=args.weight_decay)
# criterions
self.criterion = SegmentationLosses(se_loss=args.se_loss,
aux=args.aux,
nclass=self.nclass,
se_weight=args.se_weight,
aux_weight=args.aux_weight)
        self.model, self.optimizer = model, optimizer
        # initialize best_pred here so a resumed checkpoint below can overwrite it
        self.best_pred = 0.0
# using cuda
if args.cuda:
self.model = DataParallelModel(self.model).cuda()
self.criterion = DataParallelCriterion(self.criterion).cuda()
# resuming checkpoint
if args.resume is not None:
if not os.path.isfile(args.resume):
raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
if args.cuda:
self.model.module.load_state_dict(checkpoint['state_dict'])
else:
self.model.load_state_dict(checkpoint['state_dict'])
if not args.ft:
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.best_pred = checkpoint['best_pred']
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
# clear start epoch if fine-tuning
if args.ft:
args.start_epoch = 0
# lr scheduler
self.scheduler = utils.LR_Scheduler_Head(args.lr_scheduler, args.lr,
args.epochs, len(self.trainloader))
def training(self, epoch):
train_loss = 0.0
self.model.train()
tbar = tqdm(self.trainloader)
for i, (image, target) in enumerate(tbar):
self.scheduler(self.optimizer, i, epoch, self.best_pred)
self.optimizer.zero_grad()
outputs = self.model(image)
loss = self.criterion(outputs, target)
loss.backward()
self.optimizer.step()
train_loss += loss.item()
tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
if self.args.no_val:
# save checkpoint every epoch
is_best = False
utils.save_checkpoint({
'epoch': epoch + 1,
'state_dict': self.model.module.state_dict(),
'optimizer': self.optimizer.state_dict(),
'best_pred': self.best_pred,
}, self.args, is_best)
def validation(self, epoch):
# Fast test during the training
def eval_batch(model, image, target):
outputs = model(image)
outputs = gather(outputs, 0, dim=0)
pred = outputs[0]
target = target.cuda()
correct, labeled = utils.batch_pix_accuracy(pred.data, target)
inter, union = utils.batch_intersection_union(pred.data, target, self.nclass)
return correct, labeled, inter, union
is_best = False
self.model.eval()
total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
tbar = tqdm(self.valloader, desc='\r')
for i, (image, target) in enumerate(tbar):
with torch.no_grad():
correct, labeled, inter, union = eval_batch(self.model, image, target)
total_correct += correct
total_label += labeled
total_inter += inter
total_union += union
pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
mIoU = IoU.mean()
tbar.set_description(
'pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
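        # model selection: the 'best' checkpoint maximizes the simple average of
        # pixel accuracy and mIoU computed below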
new_pred = (pixAcc + mIoU)/2
if new_pred > self.best_pred:
is_best = True
self.best_pred = new_pred
utils.save_checkpoint({
'epoch': epoch + 1,
'state_dict': self.model.module.state_dict(),
'optimizer': self.optimizer.state_dict(),
'best_pred': self.best_pred,
}, self.args, is_best)
if __name__ == "__main__":
args = Options().parse()
torch.manual_seed(args.seed)
trainer = Trainer(args)
print('Starting Epoch:', trainer.args.start_epoch)
    print('Total Epochs:', trainer.args.epochs)
if args.eval:
trainer.validation(trainer.args.start_epoch)
else:
for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
trainer.training(epoch)
if not trainer.args.no_val:
trainer.validation(epoch)
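
# Illustrative launch using the flags defined in Options above (a sketch):
#   python train.py --dataset ade20k --model encnet --backbone resnet50 --aux --se-loss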
@@ -95,6 +95,8 @@ class Options():
         # evaluation option
         parser.add_argument('--eval', action='store_true', default= False,
                             help='evaluating mIoU')
+        parser.add_argument('--export', type=str, default=None,
+                            help='put the path to resuming file if needed')
         parser.add_argument('--test-val', action='store_true', default= False,
                             help='generate masks on val set')
         # test option
@@ -138,22 +140,6 @@ class Options():
         print(args)
         return args
 
-def torch_dist_avg(gpu, *args):
-    process_group = torch.distributed.group.WORLD
-    tensor_args = []
-    pending_res = []
-    for arg in args:
-        if isinstance(arg, torch.Tensor):
-            tensor_arg = arg.clone().reshape(1).detach().cuda(gpu)
-        else:
-            tensor_arg = torch.tensor(arg).reshape(1).cuda(gpu)
-        tensor_args.append(tensor_arg)
-        pending_res.append(torch.distributed.all_reduce(tensor_arg, group=process_group, async_op=True))
-    for res in pending_res:
-        res.wait()
-    ret = [x.item()/len(tensor_args) for x in tensor_args]
-    return ret
-
 def main():
     args = Options().parse()
     ngpus_per_node = torch.cuda.device_count()
@@ -247,6 +233,7 @@ def main_worker(gpu, ngpus_per_node, args):
                                      args.epochs, len(trainloader))
 
     def training(epoch):
+        train_sampler.set_epoch(epoch)
         global best_pred
         train_loss = 0.0
         model.train()
@@ -275,19 +262,23 @@ def main_worker(gpu, ngpus_per_node, args):
         for i, (image, target) in enumerate(valloader):
             with torch.no_grad():
-                #correct, labeled, inter, union = eval_batch(model, image, target)
                 pred = model(image)[0]
                 target = target.cuda(args.gpu)
                 metric.update(target, pred)
-                pixAcc, mIoU = metric.get()
-                if i % 100 == 0 and args.gpu == 0:
-                    print('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
+                if i % 100 == 0:
+                    all_metircs = metric.get_all()
+                    all_metircs = utils.torch_dist_sum(args.gpu, *all_metircs)
+                    pixAcc, mIoU = utils.get_pixacc_miou(*all_metircs)
+                    if args.gpu == 0:
+                        print('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
 
+        all_metircs = metric.get_all()
+        all_metircs = utils.torch_dist_sum(args.gpu, *all_metircs)
+        pixAcc, mIoU = utils.get_pixacc_miou(*all_metircs)
         if args.gpu == 0:
-            pixAcc, mIoU = torch_dist_avg(args.gpu, pixAcc, mIoU)
             print('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
+        if args.eval:
+            return
         new_pred = (pixAcc + mIoU)/2
         if new_pred > best_pred:
             is_best = True
@@ -299,6 +290,15 @@ def main_worker(gpu, ngpus_per_node, args):
                 'best_pred': best_pred,
             }, args, is_best)
 
+    if args.export:
+        if args.gpu == 0:
+            torch.save(model.module.state_dict(), args.export + '.pth')
+        return
+
+    if args.eval:
+        validation(args.start_epoch)
+        return
+
     if args.gpu == 0:
         print('Starting Epoch:', args.start_epoch)
         print('Total Epoches:', args.epochs)
@@ -306,13 +306,13 @@ def main_worker(gpu, ngpus_per_node, args):
     for epoch in range(args.start_epoch, args.epochs):
         tic = time.time()
         training(epoch)
-        if epoch % 10 == 0:
+        if epoch % 10 == 0 or epoch == args.epochs - 1:
             validation(epoch)
         elapsed = time.time() - tic
         if args.gpu == 0:
             print(f'Epoch: {epoch}, Time cost: {elapsed}')
-    validation(epoch)
+    #validation(epoch)
 
 if __name__ == "__main__":