Unverified Commit f8919197 authored by Hang Zhang, committed by GitHub
parent d4e19553
......@@ -7,8 +7,10 @@ setup(
CUDAExtension('enclib_gpu', [
'operator.cpp',
'encoding_kernel.cu',
'encodingv2_kernel.cu',
'syncbn_kernel.cu',
'roi_align_kernel.cu',
'nms_kernel.cu',
]),
],
cmdclass={
......
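For reference, a minimal standalone sketch of the build script this hunk extends, assuming the stock torch.utils.cpp_extension flow; only the kernel list is taken from the diff, the rest is illustrative.

```python
# Hedged sketch, not the repo's exact setup.py: a minimal build script
# using torch.utils.cpp_extension with the kernel list from this hunk.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='enclib_gpu',
    ext_modules=[
        CUDAExtension('enclib_gpu', [
            'operator.cpp',
            'encoding_kernel.cu',
            'encodingv2_kernel.cu',
            'syncbn_kernel.cu',
            'roi_align_kernel.cu',
            'nms_kernel.cu',
        ]),
    ],
    # BuildExtension drives the nvcc/gcc compile; the original cmdclass
    # body is elided in the diff above.
    cmdclass={'build_ext': BuildExtension},
)
```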
#include <ATen/ATen.h>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "common.h"
#include "device_tensor.h"
......@@ -180,7 +181,7 @@ at::Tensor BatchNorm_Forward_CUDA(
const at::Tensor gamma_,
const at::Tensor beta_) {
auto output_ = at::zeros_like(input_);
cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Forward_CUDA", ([&] {
......@@ -214,7 +215,7 @@ std::vector<at::Tensor> BatchNorm_Backward_CUDA(
at::Tensor gradMean_ = at::zeros_like(mean_);
at::Tensor gradStd_ = at::zeros_like(std_);
/* cuda utils*/
cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Backward_CUDA", ([&] {
......@@ -246,10 +247,10 @@ std::vector<at::Tensor> Sum_Square_Forward_CUDA(
at::Tensor sum_ = input_.type().tensor({input_.size(1)}).zero_();
at::Tensor square_ = input_.type().tensor({input_.size(1)}).zero_();
/* cuda utils*/
cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Backward_CUDA", ([&] {
AT_DISPATCH_FLOATING_TYPES(input_.type(), "SumSquare_forward_CUDA", ([&] {
/* Device tensors */
DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);
DeviceTensor<scalar_t, 1> sum = devicetensor<scalar_t, 1>(sum_);
......@@ -269,10 +270,10 @@ at::Tensor Sum_Square_Backward_CUDA(
/* outputs */
at::Tensor gradInput_ = at::zeros_like(input_);
/* cuda utils*/
cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Backward_CUDA", ([&] {
AT_DISPATCH_FLOATING_TYPES(input_.type(), "SumSquare_Backward_CUDA", ([&] {
/* Device tensors */
DeviceTensor<scalar_t, 3> gradInput = devicetensor<scalar_t, 3>(gradInput_);
DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);
......
......@@ -24,7 +24,7 @@ __all__ = ['BaseNet', 'MultiEvalModule']
class BaseNet(nn.Module):
def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
base_size=576, crop_size=608, mean=[.485, .456, .406],
base_size=520, crop_size=480, mean=[.485, .456, .406],
std=[.229, .224, .225], root='~/.encoding/models'):
super(BaseNet, self).__init__()
self.nclass = nclass
......@@ -99,6 +99,8 @@ class MultiEvalModule(DataParallel):
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
outputs = self.parallel_apply(replicas, inputs, kwargs)
#for out in outputs:
# print('out.size()', out.size())
return outputs
def forward(self, image):
......
......@@ -14,7 +14,8 @@ from .base import BaseNet
from .fcn import FCNHead
__all__ = ['EncNet', 'EncModule', 'get_encnet', 'get_encnet_resnet50_pcontext',
'get_encnet_resnet101_pcontext', 'get_encnet_resnet50_ade']
'get_encnet_resnet101_pcontext', 'get_encnet_resnet50_ade',
'get_encnet_resnet101_ade']
class EncNet(BaseNet):
def __init__(self, nclass, backbone, aux=True, se_loss=True, lateral=False,
......@@ -43,8 +44,6 @@ class EncNet(BaseNet):
class EncModule(nn.Module):
def __init__(self, in_channels, nclass, ncodes=32, se_loss=True, norm_layer=None):
super(EncModule, self).__init__()
#norm_layer = nn.BatchNorm1d if isinstance(norm_layer, nn.BatchNorm2d) else \
# encoding.nn.BatchNorm1d
self.se_loss = se_loss
self.encoding = nn.Sequential(
nn.Conv2d(in_channels, in_channels, 1, bias=False),
......@@ -140,9 +139,9 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
'ade20k': 'ade',
'pcontext': 'pcontext',
}
kwargs['lateral'] = True if dataset.lower() == 'pcontext' else False
kwargs['lateral'] = True if dataset.lower().startswith('p') else False
# infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
from ..datasets import datasets
model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
......@@ -167,7 +166,8 @@ def get_encnet_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **
>>> model = get_encnet_resnet50_pcontext(pretrained=True)
>>> print(model)
"""
return get_encnet('pcontext', 'resnet50', pretrained, root=root, aux=True, **kwargs)
return get_encnet('pcontext', 'resnet50', pretrained, root=root, aux=True,
base_size=520, crop_size=480, **kwargs)
def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
......@@ -186,7 +186,8 @@ def get_encnet_resnet101_pcontext(pretrained=False, root='~/.encoding/models', *
>>> model = get_encnet_resnet101_pcontext(pretrained=True)
>>> print(model)
"""
return get_encnet('pcontext', 'resnet101', pretrained, root=root, aux=True, **kwargs)
return get_encnet('pcontext', 'resnet101', pretrained, root=root, aux=True,
base_size=520, crop_size=480, **kwargs)
def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
......@@ -205,4 +206,45 @@ def get_encnet_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwarg
>>> model = get_encnet_resnet50_ade(pretrained=True)
>>> print(model)
"""
return get_encnet('ade20k', 'resnet50', pretrained, root=root, aux=True, **kwargs)
return get_encnet('ade20k', 'resnet50', pretrained, root=root, aux=True,
base_size=520, crop_size=480, **kwargs)
def get_encnet_resnet101_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_encnet_resnet101_ade(pretrained=True)
>>> print(model)
"""
return get_encnet('ade20k', 'resnet101', pretrained, root=root, aux=True,
base_size=640, crop_size=576, **kwargs)
def get_encnet_resnet152_ade(pretrained=False, root='~/.encoding/models', **kwargs):
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
<https://arxiv.org/pdf/1803.08904.pdf>`_
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.encoding/models'
Location for keeping the model parameters.
Examples
--------
>>> model = get_encnet_resnet152_ade(pretrained=True)
>>> print(model)
"""
return get_encnet('ade20k', 'resnet152', pretrained, root=root, aux=True,
base_size=520, crop_size=480, **kwargs)
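A hedged usage sketch for the constructors exported above, assuming the package is installed; pretrained=True would additionally require the published checkpoint.

```python
# Hedged sketch: building one of the new ADE20K variants directly.
import encoding

model = encoding.models.get_encnet_resnet101_ade(pretrained=False)
model.eval()  # base_size=640, crop_size=576 are baked into this variant
```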
......@@ -101,8 +101,7 @@ def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False,
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('fcn_%s_%s'%(backbone, acronyms[dataset]), root=root)),
strict= False)
get_model_file('fcn_%s_%s'%(backbone, acronyms[dataset]), root=root)))
return model
def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwargs):
......
......@@ -7,15 +7,19 @@ import zipfile
from ..utils import download, check_sha1
_model_sha1 = {name: checksum for checksum, name in [
('853f2fb07aeb2927f7696e166b215609a987fd44', 'resnet50'),
('5be5422ad7cb6a2e5f5a54070d0aa9affe69a9a4', 'resnet101'),
('6cb047cda851de6aa31963e779fae5f4c299056a', 'deepten_minc'),
('ebb6acbbd1d1c90b7f446ae59d30bf70c74febc1', 'resnet50'),
('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'),
('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'),
('2e22611a7f3992ebdee6726af169991bc26d7363', 'deepten_minc'),
('fc8c0b795abf0133700c2d4265d2f9edab7eb6cc', 'fcn_resnet50_ade'),
('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'),
('558e8904e123813f23dc0347acba85224650fe5f', 'encnet_resnet50_ade'),
('7846a2f065e90ce70d268ba8ada1a92251587734', 'encnet_resnet50_pcontext'),
('6f7c372259988bc2b6d7fc0007182e7835c31a11', 'encnet_resnet101_pcontext'),
('075195c5237b778c718fd73ceddfa1376c18dfd0', 'deeplab_resnet50_ade'),
('5ee47ee28b480cc781a195d13b5806d5bbc616bf', 'encnet_resnet101_coco'),
('4de91d5922d4d3264f678b663f874da72e82db00', 'encnet_resnet50_pcontext'),
('9f27ea13d514d7010e59988341bcbd4140fcc33d', 'encnet_resnet101_pcontext'),
('07ac287cd77e53ea583f37454e17d30ce1509a4a', 'encnet_resnet50_ade'),
('3f54fa3b67bac7619cd9b3673f5c8227cf8f4718', 'encnet_resnet101_ade'),
]}
encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'
......@@ -52,9 +56,10 @@ def get_model_file(name, root=os.path.join('~', '.encoding', 'models')):
if check_sha1(file_path, sha1_hash):
return file_path
else:
print('Mismatch in the content of model file detected. Downloading again.')
print(('Mismatch in the content of model file {} detected.' +
' Downloading again.').format(file_path))
else:
print('Model file is not found. Downloading.')
print('Model file {} is not found. Downloading.'.format(file_path))
if not os.path.exists(root):
os.makedirs(root)
......
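A hedged sketch of the download-and-verify flow around the updated hash table: checkpoints are fetched by short name and validated with check_sha1 against _model_sha1. The filename convention (name-shorthash.pth) follows the deepten_minc example elsewhere in this commit.

```python
# Hedged sketch of fetching a checkpoint by name; root defaults to
# ~/.encoding/models and the sha1 table above validates the download.
from encoding.models.model_store import get_model_file

path = get_model_file('encnet_resnet101_ade', root='~/.encoding/models')
print(path)  # e.g. .../encnet_resnet101_ade-3f54fa3b.pth
```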
......@@ -3,6 +3,7 @@
from .fcn import *
from .psp import *
from .encnet import *
from .deeplab import *
__all__ = ['get_model']
......@@ -29,8 +30,10 @@ def get_model(name, **kwargs):
'encnet_resnet50_pcontext': get_encnet_resnet50_pcontext,
'encnet_resnet101_pcontext': get_encnet_resnet101_pcontext,
'encnet_resnet50_ade': get_encnet_resnet50_ade,
'encnet_resnet101_ade': get_encnet_resnet101_ade,
'fcn_resnet50_ade': get_fcn_resnet50_ade,
'psp_resnet50_ade': get_psp_resnet50_ade,
'deeplab_resnet50_ade': get_deeplab_resnet50_ade,
}
name = name.lower()
if name not in models:
......
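A hedged sketch of resolving the newly registered names through the factory; the dict above maps lowercase names to constructors, so the new entries resolve like the existing ones.

```python
# Hedged sketch: name-based model construction via the factory above.
import encoding

model = encoding.models.get_model('deeplab_resnet50_ade', pretrained=False)
```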
......@@ -58,7 +58,7 @@ def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False,
'ade20k': 'ade',
}
# infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
from ..datasets import datasets
model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
......
......@@ -40,7 +40,7 @@ def softmax_crossentropy(input, target, weight, size_average, ignore_index, redu
class SegmentationLosses(CrossEntropyLoss):
"""2D Cross Entropy Loss with Auxilary Loss"""
def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
aux=False, aux_weight=0.2, weight=None,
aux=False, aux_weight=0.4, weight=None,
size_average=True, ignore_index=-1):
super(SegmentationLosses, self).__init__(weight, size_average, ignore_index)
self.se_loss = se_loss
......@@ -62,14 +62,14 @@ class SegmentationLosses(CrossEntropyLoss):
pred, se_pred, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
loss1 = super(SegmentationLosses, self).forward(pred, target)
loss2 = self.bceloss(F.sigmoid(se_pred), se_target)
loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.se_weight * loss2
else:
pred1, se_pred, pred2, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
loss1 = super(SegmentationLosses, self).forward(pred1, target)
loss2 = super(SegmentationLosses, self).forward(pred2, target)
loss3 = self.bceloss(F.sigmoid(se_pred), se_target)
loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.aux_weight * loss2 + self.se_weight * loss3
@staticmethod
......
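A standalone sketch of the weighting SegmentationLosses applies after this change (aux_weight default raised to 0.4, torch.sigmoid replacing the deprecated F.sigmoid); the function name and shapes here are illustrative, not the repo's API.

```python
# Illustrative re-statement of the three-term loss, assuming logits
# pred1/pred2 of shape (B, nclass, H, W), labels (B, H, W), and a
# per-class se_target built as in _get_batch_label_vector.
import torch
import torch.nn.functional as F

def combined_seg_loss(pred1, pred2, se_pred, target, se_target,
                      aux_weight=0.4, se_weight=0.2, ignore_index=-1):
    loss1 = F.cross_entropy(pred1, target, ignore_index=ignore_index)
    loss2 = F.cross_entropy(pred2, target, ignore_index=ignore_index)
    loss3 = F.binary_cross_entropy(torch.sigmoid(se_pred), se_target)
    return loss1 + aux_weight * loss2 + se_weight * loss3
```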
......@@ -15,7 +15,7 @@ import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules.utils import _pair
from ..functions import scaledL2, aggregate, pairwise_cosine
from ..functions import scaled_l2, aggregate, pairwise_cosine
__all__ = ['Encoding', 'EncodingDrop', 'Inspiration', 'UpsampleConv2d']
......@@ -90,18 +90,17 @@ class Encoding(Module):
def forward(self, X):
# input X is a 4D tensor
assert(X.size(1) == self.D)
B, D = X.size(0), self.D
if X.dim() == 3:
# BxDxN
B, D = X.size(0), self.D
# BxDxN => BxNxD
X = X.transpose(1, 2).contiguous()
elif X.dim() == 4:
# BxDxHxW
B, D = X.size(0), self.D
# BxDxHxW => Bx(HW)xD
X = X.view(B, D, -1).transpose(1, 2).contiguous()
else:
raise RuntimeError('Encoding Layer unknown input dims!')
# assignment weights NxKxD
A = F.softmax(scaledL2(X, self.codewords, self.scale), dim=1)
# assignment weights BxNxK
A = F.softmax(scaled_l2(X, self.codewords, self.scale), dim=2)
# aggregate
E = aggregate(A, X, self.codewords)
return E
......@@ -149,7 +148,7 @@ class EncodingDrop(Module):
raise RuntimeError('Encoding Layer unknown input dims!')
self._drop()
# assignment weights
A = F.softmax(scaledL2(X, self.codewords, self.scale), dim=1)
A = F.softmax(scaled_l2(X, self.codewords, self.scale), dim=2)
# aggregate
E = aggregate(A, X, self.codewords)
self._drop()
......
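The dim fix matters because scaled_l2 produces assignments of shape B x N x K, so the softmax must normalize over the K codewords (dim=2), not over the N positions. A reference sketch, with scaled_l2_ref as an illustrative re-implementation of the CUDA op, not the repo's kernel:

```python
import torch

def scaled_l2_ref(X, C, S):
    # X: (B, N, D) features, C: (K, D) codewords, S: (K,) scales
    # returns (B, N, K) scaled squared distances
    return S * (X.unsqueeze(2) - C.view(1, 1, *C.size())).pow(2).sum(3)

B, N, D, K = 2, 5, 8, 32  # toy sizes
X, C, S = torch.randn(B, N, D), torch.randn(K, D), torch.randn(K)
A = torch.softmax(scaled_l2_ref(X, C, S), dim=2)  # (B, N, K)
assert torch.allclose(A.sum(2), torch.ones(B, N))  # sums to 1 over K
```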
......@@ -10,7 +10,7 @@
"""Encoding Util Tools"""
from .lr_scheduler import LR_Scheduler
from .metrics import batch_intersection_union, batch_pix_accuracy
from .metrics import SegmentationMetric, batch_intersection_union, batch_pix_accuracy
from .pallete import get_mask_pallete
from .train_helper import get_selabel_vector, EMA
from .presets import load_image
......
......@@ -8,18 +8,70 @@
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import threading
import numpy as np
import torch
def batch_pix_accuracy(predict, target):
class SegmentationMetric(object):
"""Computes pixAcc and mIoU metric scroes
"""
def __init__(self, nclass):
self.nclass = nclass
self.lock = threading.Lock()
self.reset()
def update(self, labels, preds):
def evaluate_worker(self, label, pred):
correct, labeled = batch_pix_accuracy(
pred, label)
inter, union = batch_intersection_union(
pred, label, self.nclass)
with self.lock:
self.total_correct += correct
self.total_label += labeled
self.total_inter += inter
self.total_union += union
return
if isinstance(preds, torch.Tensor):
evaluate_worker(self, labels, preds)
elif isinstance(preds, (list, tuple)):
threads = [threading.Thread(target=evaluate_worker,
args=(self, label, pred),
)
for (label, pred) in zip(labels, preds)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
raise NotImplementedError
def get(self):
pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)
IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)
mIoU = IoU.mean()
return pixAcc, mIoU
def reset(self):
self.total_inter = 0
self.total_union = 0
self.total_correct = 0
self.total_label = 0
return
def batch_pix_accuracy(output, target):
"""Batch Pixel Accuracy
Args:
output: input 4D tensor
target: label 3D tensor
"""
_, predict = torch.max(predict, 1)
predict = predict.cpu().numpy() + 1
target = target.cpu().numpy() + 1
_, predict = torch.max(output, 1)
predict = predict.cpu().numpy().astype('int64') + 1
target = target.cpu().numpy().astype('int64') + 1
pixel_labeled = np.sum(target > 0)
pixel_correct = np.sum((predict == target)*(target > 0))
assert pixel_correct <= pixel_labeled, \
......@@ -27,19 +79,19 @@ def batch_pix_accuracy(predict, target):
return pixel_correct, pixel_labeled
def batch_intersection_union(predict, target, nclass):
def batch_intersection_union(output, target, nclass):
"""Batch Intersection of Union
Args:
output: input 4D tensor
target: label 3D tensor
nclass: number of categories (int)
"""
_, predict = torch.max(predict, 1)
_, predict = torch.max(output, 1)
mini = 1
maxi = nclass
nbins = nclass
predict = predict.cpu().numpy() + 1
target = target.cpu().numpy() + 1
predict = predict.cpu().numpy().astype('int64') + 1
target = target.cpu().numpy().astype('int64') + 1
predict = predict * (target > 0).astype(predict.dtype)
intersection = predict * (predict == target)
......
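A hedged usage sketch of the new thread-safe metric, assuming 21 classes and logits/labels shaped as the helpers above expect; update() accepts a single tensor pair or lists of per-GPU outputs.

```python
import torch
from encoding.utils import SegmentationMetric

metric = SegmentationMetric(nclass=21)
output = torch.randn(2, 21, 60, 60)         # (B, nclass, H, W) logits
target = torch.randint(0, 21, (2, 60, 60))  # (B, H, W) labels
metric.update(target, output)               # update(labels, preds)
pixAcc, mIoU = metric.get()
print('pixAcc %.4f, mIoU %.4f' % (pixAcc, mIoU))
```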
......@@ -21,11 +21,10 @@ def get_mask_pallete(npimg, dataset='detail'):
out_img.putpalette(adepallete)
elif dataset == 'cityscapes':
out_img.putpalette(citypallete)
else:
elif dataset in ('detail', 'pascal_voc', 'pascal_aug'):
out_img.putpalette(vocpallete)
return out_img
def _get_voc_pallete(num_cls):
n = num_cls
pallete = [0]*(n*3)
......
......@@ -94,9 +94,9 @@ class Dataloader():
normalize,
])
trainset = MINCDataloder(root=os.path.expanduser('~/data/minc-2500/'),
trainset = MINCDataloder(root=os.path.expanduser('~/.encoding/data/minc-2500/'),
train=True, transform=transform_train)
testset = MINCDataloder(root=os.path.expanduser('~/data/minc-2500/'),
testset = MINCDataloder(root=os.path.expanduser('~/.encoding/data/minc-2500/'),
train=False, transform=transform_test)
kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {}
......
......@@ -96,7 +96,7 @@ def main():
train_loss += loss.data.item()
pred = output.data.max(1)[1]
correct += pred.eq(target.data).cpu().sum()
correct += pred.eq(target.data).cpu().sum().item()
total += target.size(0)
err = 100.0 - 100.0 * correct / total
tbar.set_description('\rLoss: %.3f | Err: %.3f%% (%d/%d)' % \
......
......@@ -2,4 +2,4 @@ import encoding
import shutil
encoding.models.get_model_file('deepten_minc', root='./')
shutil.move('deepten_minc-6cb047cd.pth', 'deepten_minc.pth')
shutil.move('deepten_minc-2e22611a.pth', 'deepten_minc.pth')
......@@ -86,6 +86,46 @@ class Bottleneck(nn.Module):
return residual + self.conv_block(x)
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
class EncLayerV2(nn.Module):
def __init__(self, channel, K=16, reduction=4):
super(EncLayerV2, self).__init__()
out_channel = int(channel / reduction)
self.fc = nn.Sequential(
nn.Conv2d(channel, out_channel, 1),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True),
encoding.nn.EncodingV2(D=out_channel,K=K),
encoding.nn.View(-1, out_channel*K),
encoding.nn.Normalize(),
nn.Linear(out_channel*K, channel),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.fc(x).view(b, c, 1, 1)
return x * y
class EncLayerV3(nn.Module):
def __init__(self, channel, K=16, reduction=4):
super(EncLayerV3, self).__init__()
out_channel = int(channel / reduction)
self.fc = nn.Sequential(
nn.Conv2d(channel, out_channel, 1),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True),
encoding.nn.EncodingV3(D=out_channel,K=K),
encoding.nn.View(-1, out_channel*K),
encoding.nn.Normalize(),
nn.Linear(out_channel*K, channel),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.fc(x).view(b, c, 1, 1)
return x * y
class EncLayer(nn.Module):
def __init__(self, channel, K=16, reduction=4):
super(EncLayer, self).__init__()
......
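A toy forward pass through the SE-style gate defined above; sizes are illustrative and assume encoding.nn.EncodingV2 builds as in this commit (it relies on the encodingv2_kernel.cu added to the extension).

```python
import torch

layer = EncLayerV2(channel=256, K=16, reduction=4)
x = torch.randn(4, 256, 14, 14)
y = layer(x)                  # channel-wise rescaled, same shape as x
assert y.shape == x.shape
```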
......@@ -40,8 +40,8 @@ class Options():
# lr setting
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
help='learning rate (default: 0.1)')
parser.add_argument('--lr-scheduler', type=str, default='step',
help='learning rate scheduler (default: step)')
parser.add_argument('--lr-scheduler', type=str, default='cos',
help='learning rate scheduler (default: cos)')
parser.add_argument('--lr-step', type=int, default=40, metavar='LR',
help='learning rate step (default: 40)')
# optimizer
......
......@@ -25,15 +25,21 @@ class Options():
$(HOME)/data)')
parser.add_argument('--workers', type=int, default=16,
metavar='N', help='dataloader threads')
parser.add_argument('--base-size', type=int, default=608,
parser.add_argument('--base-size', type=int, default=520,
help='base image size')
parser.add_argument('--crop-size', type=int, default=576,
parser.add_argument('--crop-size', type=int, default=480,
help='crop image size')
parser.add_argument('--train-split', type=str, default='train',
help='dataset train split (default: train)')
# training hyper params
parser.add_argument('--aux', action='store_true', default= False,
help='Auxiliary Loss')
parser.add_argument('--aux-weight', type=float, default=0.2,
help='Auxiliary loss weight (default: 0.2)')
parser.add_argument('--se-loss', action='store_true', default= False,
help='Semantic Encoding Loss SE-loss')
parser.add_argument('--se-weight', type=float, default=0.2,
help='SE-loss weight (default: 0.2)')
parser.add_argument('--epochs', type=int, default=None, metavar='N',
help='number of epochs to train (default: auto)')
parser.add_argument('--start_epoch', type=int, default=0,
......@@ -68,12 +74,7 @@ class Options():
# finetuning pre-trained models
parser.add_argument('--ft', action='store_true', default= False,
help='finetuning on a different dataset')
parser.add_argument('--pre-class', type=int, default=None,
help='num of pre-trained classes \
(default: None)')
# evaluation option
parser.add_argument('--ema', action='store_true', default= False,
help='using EMA evaluation')
parser.add_argument('--eval', action='store_true', default= False,
help='evaluating mIoU')
parser.add_argument('--no-val', action='store_true', default= False,
......@@ -90,10 +91,12 @@ class Options():
# default settings for epochs, batch_size and lr
if args.epochs is None:
epoches = {
'coco': 30,
'citys': 180,
'pascal_voc': 50,
'pascal_aug': 50,
'pcontext': 80,
'ade20k': 160,
'ade20k': 120,
}
args.epochs = epoches[args.dataset.lower()]
if args.batch_size is None:
......@@ -102,10 +105,13 @@ class Options():
args.test_batch_size = args.batch_size
if args.lr is None:
lrs = {
'coco': 0.01,
'citys': 0.01,
'pascal_voc': 0.0001,
'pascal_aug': 0.001,
'pcontext': 0.001,
'ade20k': 0.01,
}
args.lr = lrs[args.dataset.lower()] / 16 * args.batch_size
print(args)
return args
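The learning-rate defaults follow a linear scaling rule: each per-dataset value in the lrs table is defined for batch size 16 and rescaled to the actual batch size. A worked example using the ade20k entry above:

```python
# Worked example of the scaling on the last line of the hunk above.
base_lr = 0.01                    # lrs['ade20k']
batch_size = 8
lr = base_lr / 16 * batch_size    # -> 0.005
```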
......@@ -21,10 +21,6 @@ from encoding.models import get_model, get_segmentation_model, MultiEvalModule
from option import Options
torch_ver = torch.__version__[:3]
if torch_ver == '0.3':
from torch.autograd import Variable
def test(args):
# output folder
outdir = 'outdir'
......@@ -64,58 +60,29 @@ def test(args):
print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
print(model)
evaluator = MultiEvalModule(model, testset.num_class).cuda()
scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
evaluator = MultiEvalModule(model, testset.num_class, scales=scales).cuda()
evaluator.eval()
metric = utils.SegmentationMetric(testset.num_class)
tbar = tqdm(test_data)
def eval_batch(image, dst, evaluator, eval_mode):
if eval_mode:
# evaluation mode on validation set
targets = dst
outputs = evaluator.parallel_forward(image)
batch_inter, batch_union, batch_correct, batch_label = 0, 0, 0, 0
for output, target in zip(outputs, targets):
correct, labeled = utils.batch_pix_accuracy(output.data.cpu(), target)
inter, union = utils.batch_intersection_union(
output.data.cpu(), target, testset.num_class)
batch_correct += correct
batch_label += labeled
batch_inter += inter
batch_union += union
return batch_correct, batch_label, batch_inter, batch_union
for i, (image, dst) in enumerate(tbar):
if args.eval:
with torch.no_grad():
predicts = evaluator.parallel_forward(image)
metric.update(dst, predicts)
pixAcc, mIoU = metric.get()
tbar.set_description( 'pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
else:
# test mode, dump the results
im_paths = dst
outputs = evaluator.parallel_forward(image)
predicts = [torch.max(output, 1)[1].cpu().numpy() + testset.pred_offset
for output in outputs]
for predict, impath in zip(predicts, im_paths):
with torch.no_grad():
outputs = evaluator.parallel_forward(image)
predicts = [testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
for output in outputs]
for predict, impath in zip(predicts, dst):
mask = utils.get_mask_pallete(predict, args.dataset)
outname = os.path.splitext(impath)[0] + '.png'
mask.save(os.path.join(outdir, outname))
# dummy outputs for compatible with eval mode
return 0, 0, 0, 0
total_inter, total_union, total_correct, total_label = \
np.int64(0), np.int64(0), np.int64(0), np.int64(0)
for i, (image, dst) in enumerate(tbar):
if torch_ver == "0.3":
image = Variable(image, volatile=True)
correct, labeled, inter, union = eval_batch(image, dst, evaluator, args.eval)
else:
with torch.no_grad():
correct, labeled, inter, union = eval_batch(image, dst, evaluator, args.eval)
if args.eval:
total_correct += correct.astype('int64')
total_label += labeled.astype('int64')
total_inter += inter.astype('int64')
total_union += union.astype('int64')
pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
IoU = np.float64(1.0) * total_inter / (np.spacing(1, dtype=np.float64) + total_union)
mIoU = IoU.mean()
tbar.set_description(
'pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
if __name__ == "__main__":
args = Options().parse()
......