Unverified Commit ce461dae authored by Hang Zhang's avatar Hang Zhang Committed by GitHub
Browse files

V1.0.0 (#156)

* v1.0
parent c2cb2aab
......@@ -8,7 +8,8 @@ import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import upsample
from torch.nn.functional import interpolate
from ..nn import ConcurrentModule, SyncBatchNorm
from .base import BaseNet
......@@ -38,9 +39,10 @@ class FCN(BaseNet):
>>> model = FCN(nclass=21, backbone='resnet50')
>>> print(model)
"""
def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
def __init__(self, nclass, backbone, aux=True, se_loss=False, with_global=False,
norm_layer=SyncBatchNorm, **kwargs):
super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
self.head = FCNHead(2048, nclass, norm_layer)
self.head = FCNHead(2048, nclass, norm_layer, self._up_kwargs, with_global)
if aux:
self.auxlayer = FCNHead(1024, nclass, norm_layer)
......@@ -49,19 +51,54 @@ class FCN(BaseNet):
_, _, c3, c4 = self.base_forward(x)
x = self.head(c4)
x = upsample(x, imsize, **self._up_kwargs)
x = interpolate(x, imsize, **self._up_kwargs)
outputs = [x]
if self.aux:
auxout = self.auxlayer(c3)
auxout = upsample(auxout, imsize, **self._up_kwargs)
auxout = interpolate(auxout, imsize, **self._up_kwargs)
outputs.append(auxout)
return tuple(outputs)
class Identity(nn.Module):
    """Pass-through module: the forward returns its input unchanged.

    Useful as the "no-op" branch of a :class:`ConcurrentModule`.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        # Deliberate no-op.
        return x
class GlobalPooling(nn.Module):
    """Global-context branch: GAP -> 1x1 conv -> norm -> ReLU, then the
    pooled 1x1 feature is resized back to the input's spatial size.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the produced context feature.
        norm_layer: normalization layer class (e.g. ``nn.BatchNorm2d``).
        up_kwargs: keyword arguments forwarded to ``interpolate``.
    """

    def __init__(self, in_channels, out_channels, norm_layer, up_kwargs):
        super(GlobalPooling, self).__init__()
        self._up_kwargs = up_kwargs
        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True),
        )

    def forward(self, x):
        height, width = x.size()[2:]
        context = self.gap(x)
        # Broadcast the 1x1 context vector back over the spatial grid.
        return interpolate(context, (height, width), **self._up_kwargs)
class FCNHead(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer):
def __init__(self, in_channels, out_channels, norm_layer, up_kwargs={}, with_global=False):
super(FCNHead, self).__init__()
inter_channels = in_channels // 4
self._up_kwargs = up_kwargs
if with_global:
self.conv5 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels),
nn.ReLU(),
ConcurrentModule([
Identity(),
GlobalPooling(inter_channels, inter_channels,
norm_layer, self._up_kwargs),
]),
nn.Dropout2d(0.1, False),
nn.Conv2d(2*inter_channels, out_channels, 1))
else:
self.conv5 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels),
nn.ReLU(),
......@@ -89,14 +126,8 @@ def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False,
>>> model = get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False)
>>> print(model)
"""
acronyms = {
'pascal_voc': 'voc',
'pascal_aug': 'voc',
'pcontext': 'pcontext',
'ade20k': 'ade',
}
# infer number of classes
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
from ..datasets import datasets, acronyms
model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
......
......@@ -7,10 +7,12 @@ import zipfile
from ..utils import download, check_sha1
_model_sha1 = {name: checksum for checksum, name in [
('ebb6acbbd1d1c90b7f446ae59d30bf70c74febc1', 'resnet50'),
('25c4b50959ef024fcc050213a06b614899f94b3d', 'resnet50'),
('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'),
('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'),
('2e22611a7f3992ebdee6726af169991bc26d7363', 'deepten_minc'),
('da4785cfc837bf00ef95b52fb218feefe703011f', 'wideresnet38'),
('b41562160173ee2e979b795c551d3c7143b1e5b5', 'wideresnet50'),
('1225f149519c7a0113c43a056153c1bb15468ac0', 'deepten_resnet50_minc'),
('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50_ade'),
('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'),
......
# pylint: disable=wildcard-import, unused-wildcard-import
from .resnet import *
from .cifarresnet import *
from .fcn import *
from .psp import *
from .encnet import *
from .deepten import *
__all__ = ['get_model']
......@@ -25,6 +28,13 @@ def get_model(name, **kwargs):
The model.
"""
models = {
'resnet18': resnet18,
'resnet34': resnet34,
'resnet50': resnet50,
'resnet101': resnet101,
'resnet152': resnet152,
'cifar_resnet20': cifar_resnet20,
'deepten_resnet50_minc': get_deepten_resnet50_minc,
'fcn_resnet50_pcontext': get_fcn_resnet50_pcontext,
'encnet_resnet50_pcontext': get_encnet_resnet50_pcontext,
'encnet_resnet101_pcontext': get_encnet_resnet101_pcontext,
......@@ -35,6 +45,6 @@ def get_model(name, **kwargs):
}
name = name.lower()
if name not in models:
raise ValueError('%s\n\t%s' % (str(e), '\n\t'.join(sorted(models.keys()))))
raise ValueError('%s\n\t%s' % (str(name), '\n\t'.join(sorted(models.keys()))))
net = models[name](**kwargs)
return net
......@@ -8,7 +8,7 @@ import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import upsample
from torch.nn.functional import interpolate
from .base import BaseNet
from .fcn import FCNHead
......@@ -27,11 +27,11 @@ class PSP(BaseNet):
outputs = []
x = self.head(c4)
x = upsample(x, (h,w), **self._up_kwargs)
x = interpolate(x, (h,w), **self._up_kwargs)
outputs.append(x)
if self.aux:
auxout = self.auxlayer(c3)
auxout = upsample(auxout, (h,w), **self._up_kwargs)
auxout = interpolate(auxout, (h,w), **self._up_kwargs)
outputs.append(auxout)
return tuple(outputs)
......@@ -52,13 +52,8 @@ class PSPHead(nn.Module):
def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False,
root='~/.encoding/models', **kwargs):
acronyms = {
'pascal_voc': 'voc',
'pascal_aug': 'voc',
'ade20k': 'ade',
}
# infer number of classes
from ..datasets import datasets
from ..datasets import datasets, acronyms
model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
if pretrained:
from .model_store import get_model_file
......
......@@ -4,6 +4,9 @@ import torch
import torch.utils.model_zoo as model_zoo
import torch.nn as nn
from ..nn import GlobalAvgPool2d
from ..models.model_store import get_model_file
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'BasicBlock', 'Bottleneck']
......@@ -132,7 +135,7 @@ class ResNet(nn.Module):
- Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
"""
# pylint: disable=unused-variable
def __init__(self, block, layers, num_classes=1000, dilated=True,
def __init__(self, block, layers, num_classes=1000, dilated=False, multi_grid=False,
deep_base=True, norm_layer=nn.BatchNorm2d):
self.inplanes = 128 if deep_base else 64
super(ResNet, self).__init__()
......@@ -157,6 +160,11 @@ class ResNet(nn.Module):
if dilated:
self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
dilation=2, norm_layer=norm_layer)
if multi_grid:
self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
dilation=4, norm_layer=norm_layer,
multi_grid=True)
else:
self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
dilation=4, norm_layer=norm_layer)
else:
......@@ -164,7 +172,7 @@ class ResNet(nn.Module):
norm_layer=norm_layer)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
norm_layer=norm_layer)
self.avgpool = nn.AvgPool2d(7, stride=1)
self.avgpool = GlobalAvgPool2d()
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
......@@ -175,7 +183,7 @@ class ResNet(nn.Module):
m.weight.data.fill_(1)
m.bias.data.zero_()
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None):
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, multi_grid=False):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
......@@ -185,7 +193,11 @@ class ResNet(nn.Module):
)
layers = []
if dilation == 1 or dilation == 2:
multi_dilations = [4, 8, 16]
if multi_grid:
layers.append(block(self.inplanes, planes, stride, dilation=multi_dilations[0],
downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer))
elif dilation == 1 or dilation == 2:
layers.append(block(self.inplanes, planes, stride, dilation=1,
downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer))
elif dilation == 4:
......@@ -196,6 +208,10 @@ class ResNet(nn.Module):
self.inplanes = planes * block.expansion
for i in range(1, blocks):
if multi_grid:
layers.append(block(self.inplanes, planes, dilation=multi_dilations[i],
previous_dilation=dilation, norm_layer=norm_layer))
else:
layers.append(block(self.inplanes, planes, dilation=dilation, previous_dilation=dilation,
norm_layer=norm_layer))
......@@ -251,7 +267,6 @@ def resnet50(pretrained=False, root='~/.encoding/models', **kwargs):
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
from ..models.model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('resnet50', root=root)), strict=False)
return model
......@@ -265,7 +280,6 @@ def resnet101(pretrained=False, root='~/.encoding/models', **kwargs):
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
from ..models.model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('resnet101', root=root)), strict=False)
return model
......@@ -279,7 +293,6 @@ def resnet152(pretrained=False, root='~/.encoding/models', **kwargs):
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
from ..models.model_store import get_model_file
model.load_state_dict(torch.load(
get_model_file('resnet152', root=root)), strict=False)
return model
......@@ -12,3 +12,4 @@
from .encoding import *
from .syncbn import *
from .customize import *
from .loss import *
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe.

    One producer thread calls :meth:`put` exactly once per round; one
    consumer thread calls :meth:`get`, blocking until the result arrives.
    """

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store *result* and wake the waiting consumer.

        Raises:
            AssertionError: if the previous result was never fetched.
        """
        with self._lock:
            # Fix: original message read "has't"; the guard itself prevents
            # overwriting an unconsumed result.
            assert self._result is None, "Previous result hasn't been fetched."
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result is available, then return and clear it."""
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        # Send our (identifier, message) pair to the master's queue ...
        self.queue.put((self.identifier, msg))
        # ... block until the master publishes this slave's result ...
        ret = self.result.get()
        # ... then ack (True) so the master knows this slave is done.
        self.queue.put(True)
        return ret
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback on
      each module, all slave devices should call `register(id)` and obtain a
      `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all
      messages from slave devices are collected and passed to a registered
      callback.
    - After receiving the messages, the master device gathers the information
      and determines the message passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected
                messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()                 # slave -> master messages
        self._registry = collections.OrderedDict()  # identifier -> _MasterRegistry
        self._activated = False                     # True once run_master ran

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with
            the master device.
        """
        if self._activated:
            # A registration after activation starts a new replication
            # round: reset all per-round state.
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.

        The messages are first collected from each device (including the
        master device), and then a callback is invoked to compute the message
        to be sent back to each device (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself.
                This will be placed as the first message when calling
                `master_callback`. For detailed usage, see
                `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        # The master's own message goes first; slaves' arrive via the queue.
        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        # Deliver each slave's result through its FutureResult pipe.
        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        # Wait for every slave to ack (SlavePipe.run_slave puts True back).
        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        # Number of currently registered slave devices.
        return len(self._registry)
......@@ -10,17 +10,27 @@
"""Encoding Custermized NN Module"""
import torch
from torch.nn import Module, Sequential, Conv2d, ReLU, AdaptiveAvgPool2d, \
NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
torch_ver = torch.__version__[:3]
__all__ = ['GramMatrix', 'SegmentationLosses', 'View', 'Sum', 'Mean',
'Normalize', 'PyramidPooling']
__all__ = ['GlobalAvgPool2d', 'GramMatrix',
'View', 'Sum', 'Mean', 'Normalize', 'ConcurrentModule',
'PyramidPooling']
class GramMatrix(Module):
class GlobalAvgPool2d(nn.Module):
    """Global average pooling over the input's spatial dimensions.

    Collapses an (N, C, H, W) tensor to (N, C).
    """

    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        # Average each HxW plane down to one value, then drop the
        # trailing 1x1 spatial dims.
        batch = inputs.size(0)
        pooled = F.adaptive_avg_pool2d(inputs, 1)
        return pooled.view(batch, -1)
class GramMatrix(nn.Module):
r""" Gram Matrix for a 4D convolutional featuremaps as a mini-batch
.. math::
......@@ -33,60 +43,7 @@ class GramMatrix(Module):
gram = features.bmm(features_t) / (ch * h * w)
return gram
def softmax_crossentropy(input, target, weight, size_average, ignore_index, reduce=True):
    """Cross entropy written as explicit log-softmax followed by NLL loss."""
    log_probs = F.log_softmax(input, 1)
    return F.nll_loss(log_probs, target, weight,
                      size_average, ignore_index, reduce)
class SegmentationLosses(CrossEntropyLoss):
    """2D Cross Entropy Loss with Auxiliary Loss.

    Dispatches on the (se_loss, aux) flags:
      * neither: plain cross entropy on (pred, target)
      * aux only: (pred1, pred2, target) -> CE(pred1) + aux_weight * CE(pred2)
      * se only:  (pred, se_pred, target) -> CE + se_weight * BCE on the
        per-image class-presence vector
      * both:     (pred1, se_pred, pred2, target) -> all three terms
    """
    def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
                 aux=False, aux_weight=0.4, weight=None,
                 size_average=True, ignore_index=-1):
        super(SegmentationLosses, self).__init__(weight, size_average, ignore_index)
        self.se_loss = se_loss        # enable semantic-encoding (class-presence) term
        self.aux = aux                # enable auxiliary-head term
        self.nclass = nclass          # number of classes for the presence vector
        self.se_weight = se_weight    # weight of the se term
        self.aux_weight = aux_weight  # weight of the aux term
        self.bceloss = BCELoss(weight, size_average)

    def forward(self, *inputs):
        if not self.se_loss and not self.aux:
            return super(SegmentationLosses, self).forward(*inputs)
        elif not self.se_loss:
            pred1, pred2, target = tuple(inputs)
            loss1 = super(SegmentationLosses, self).forward(pred1, target)
            loss2 = super(SegmentationLosses, self).forward(pred2, target)
            return loss1 + self.aux_weight * loss2
        elif not self.aux:
            pred, se_pred, target = tuple(inputs)
            # Binary class-presence targets derived from the segmentation mask.
            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
            loss1 = super(SegmentationLosses, self).forward(pred, target)
            loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
            return loss1 + self.se_weight * loss2
        else:
            pred1, se_pred, pred2, target = tuple(inputs)
            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
            loss1 = super(SegmentationLosses, self).forward(pred1, target)
            loss2 = super(SegmentationLosses, self).forward(pred2, target)
            loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
            return loss1 + self.aux_weight * loss2 + self.se_weight * loss3

    @staticmethod
    def _get_batch_label_vector(target, nclass):
        # target is a 3D Variable BxHxW, output is 2D BxnClass:
        # tvect[b, c] == 1 iff class c appears anywhere in image b.
        batch = target.size(0)
        tvect = Variable(torch.zeros(batch, nclass))
        for i in range(batch):
            # Histogram over class ids present in this image.
            hist = torch.histc(target[i].cpu().data.float(),
                               bins=nclass, min=0,
                               max=nclass-1)
            vect = hist > 0
            tvect[i] = vect
        return tvect
class View(Module):
class View(nn.Module):
"""Reshape the input into different size, an inplace operator, support
SelfParallel mode.
"""
......@@ -101,7 +58,7 @@ class View(Module):
return input.view(self.size)
class Sum(Module):
class Sum(nn.Module):
def __init__(self, dim, keep_dim=False):
super(Sum, self).__init__()
self.dim = dim
......@@ -111,7 +68,7 @@ class Sum(Module):
return input.sum(self.dim, self.keep_dim)
class Mean(Module):
class Mean(nn.Module):
def __init__(self, dim, keep_dim=False):
super(Mean, self).__init__()
self.dim = dim
......@@ -121,7 +78,7 @@ class Mean(Module):
return input.mean(self.dim, self.keep_dim)
class Normalize(Module):
class Normalize(nn.Module):
r"""Performs :math:`L_p` normalization of inputs over specified dimension.
Does:
......@@ -148,39 +105,54 @@ class Normalize(Module):
def forward(self, x):
return F.normalize(x, self.p, self.dim, eps=1e-8)
class ConcurrentModule(nn.ModuleList):
    r"""Feed to a list of modules concurrently.

    The outputs of the layers are concatenated at channel dimension.

    Args:
        modules (iterable, optional): an iterable of modules to add
    """

    def __init__(self, modules=None):
        super(ConcurrentModule, self).__init__(modules)

    def forward(self, x):
        # Run every child branch on the same input, then stack the
        # results channel-wise (dim 1).
        branch_outputs = [branch(x) for branch in self]
        return torch.cat(branch_outputs, 1)
class PyramidPooling(Module):
class PyramidPooling(nn.Module):
"""
Reference:
Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
"""
def __init__(self, in_channels, norm_layer, up_kwargs):
super(PyramidPooling, self).__init__()
self.pool1 = AdaptiveAvgPool2d(1)
self.pool2 = AdaptiveAvgPool2d(2)
self.pool3 = AdaptiveAvgPool2d(3)
self.pool4 = AdaptiveAvgPool2d(6)
self.pool1 = nn.AdaptiveAvgPool2d(1)
self.pool2 = nn.AdaptiveAvgPool2d(2)
self.pool3 = nn.AdaptiveAvgPool2d(3)
self.pool4 = nn.AdaptiveAvgPool2d(6)
out_channels = int(in_channels/4)
self.conv1 = Sequential(Conv2d(in_channels, out_channels, 1, bias=False),
self.conv1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
norm_layer(out_channels),
ReLU(True))
self.conv2 = Sequential(Conv2d(in_channels, out_channels, 1, bias=False),
nn.ReLU(True))
self.conv2 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
norm_layer(out_channels),
ReLU(True))
self.conv3 = Sequential(Conv2d(in_channels, out_channels, 1, bias=False),
nn.ReLU(True))
self.conv3 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
norm_layer(out_channels),
ReLU(True))
self.conv4 = Sequential(Conv2d(in_channels, out_channels, 1, bias=False),
nn.ReLU(True))
self.conv4 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
norm_layer(out_channels),
ReLU(True))
# bilinear upsample options
nn.ReLU(True))
# bilinear interpolate options
self._up_kwargs = up_kwargs
def forward(self, x):
_, _, h, w = x.size()
feat1 = F.upsample(self.conv1(self.pool1(x)), (h, w), **self._up_kwargs)
feat2 = F.upsample(self.conv2(self.pool2(x)), (h, w), **self._up_kwargs)
feat3 = F.upsample(self.conv3(self.pool3(x)), (h, w), **self._up_kwargs)
feat4 = F.upsample(self.conv4(self.pool4(x)), (h, w), **self._up_kwargs)
feat1 = F.interpolate(self.conv1(self.pool1(x)), (h, w), **self._up_kwargs)
feat2 = F.interpolate(self.conv2(self.pool2(x)), (h, w), **self._up_kwargs)
feat3 = F.interpolate(self.conv3(self.pool3(x)), (h, w), **self._up_kwargs)
feat4 = F.interpolate(self.conv4(self.pool4(x)), (h, w), **self._up_kwargs)
return torch.cat((x, feat1, feat2, feat3, feat4), 1)
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
__all__ = ['SegmentationLosses', 'OhemCrossEntropy2d', 'OHEMSegmentationLosses']
class SegmentationLosses(nn.CrossEntropyLoss):
    """2D Cross Entropy Loss with Auxiliary Loss.

    Dispatches on the (se_loss, aux) flags:
      * neither: plain cross entropy on (pred, target)
      * aux only: (pred1, pred2, target) -> CE(pred1) + aux_weight * CE(pred2)
      * se only:  (pred, se_pred, target) -> CE + se_weight * BCE on the
        per-image class-presence vector
      * both:     (pred1, se_pred, pred2, target) -> all three terms
    """
    def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
                 aux=False, aux_weight=0.4, weight=None,
                 ignore_index=-1):
        # Second positional arg (size_average) is None: defer to the
        # default 'mean' reduction.
        super(SegmentationLosses, self).__init__(weight, None, ignore_index)
        self.se_loss = se_loss        # enable semantic-encoding (class-presence) term
        self.aux = aux                # enable auxiliary-head term
        self.nclass = nclass          # number of classes for the presence vector
        self.se_weight = se_weight    # weight of the se term
        self.aux_weight = aux_weight  # weight of the aux term
        self.bceloss = nn.BCELoss(weight)

    def forward(self, *inputs):
        if not self.se_loss and not self.aux:
            return super(SegmentationLosses, self).forward(*inputs)
        elif not self.se_loss:
            pred1, pred2, target = tuple(inputs)
            loss1 = super(SegmentationLosses, self).forward(pred1, target)
            loss2 = super(SegmentationLosses, self).forward(pred2, target)
            return loss1 + self.aux_weight * loss2
        elif not self.aux:
            pred, se_pred, target = tuple(inputs)
            # Binary class-presence targets derived from the segmentation mask.
            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
            loss1 = super(SegmentationLosses, self).forward(pred, target)
            loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
            return loss1 + self.se_weight * loss2
        else:
            pred1, se_pred, pred2, target = tuple(inputs)
            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
            loss1 = super(SegmentationLosses, self).forward(pred1, target)
            loss2 = super(SegmentationLosses, self).forward(pred2, target)
            loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
            return loss1 + self.aux_weight * loss2 + self.se_weight * loss3

    @staticmethod
    def _get_batch_label_vector(target, nclass):
        # target is a 3D Variable BxHxW, output is 2D BxnClass:
        # tvect[b, c] == 1 iff class c appears anywhere in image b.
        batch = target.size(0)
        tvect = Variable(torch.zeros(batch, nclass))
        for i in range(batch):
            # Histogram over class ids present in this image.
            hist = torch.histc(target[i].cpu().data.float(),
                               bins=nclass, min=0,
                               max=nclass-1)
            vect = hist > 0
            tvect[i] = vect
        return tvect
# adapted from https://github.com/PkuRainBow/OCNet/blob/master/utils/loss.py
class OhemCrossEntropy2d(nn.Module):
    """Online Hard Example Mining (OHEM) cross entropy for 2D segmentation.

    Keeps only the hardest pixels (lowest predicted probability on the true
    class); easy pixels are relabelled to ``ignore_label`` so the wrapped
    CrossEntropyLoss skips them.

    Args:
        ignore_label: target value excluded from the loss.
        thresh: probability threshold; pixels predicted above it are "easy".
        min_kept: minimum number of pixels to keep even if they exceed thresh.
        use_weight: apply the fixed 19-class (Cityscapes-style) balance weights.
    """
    def __init__(self, ignore_label=-1, thresh=0.7, min_kept=100000, use_weight=True):
        super(OhemCrossEntropy2d, self).__init__()
        self.ignore_label = ignore_label
        self.thresh = float(thresh)
        self.min_kept = int(min_kept)
        if use_weight:
            print("w/ class balance")
            weight = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
                1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
                1.0865, 1.1529, 1.0507])
            self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label)
        else:
            print("w/o class balance")
            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)

    def forward(self, predict, target, weight=None):
        """
        Args:
            predict:(n, c, h, w)
            target:(n, h, w)
            weight (Tensor, optional): a manual rescaling weight given to each class.
                If given, has to be a Tensor of size "nclasses"
        """
        assert not target.requires_grad
        assert predict.dim() == 4
        assert target.dim() == 3
        assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
        assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
        # Bug fix: the failure message used target.size(3), which raises
        # IndexError on a 3D target instead of reporting the width mismatch.
        assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2))

        n, c, h, w = predict.size()
        input_label = target.data.cpu().numpy().ravel().astype(np.int32)
        # Class-first (c, n*h*w) softmax probabilities, computed in numpy.
        x = np.rollaxis(predict.data.cpu().numpy(), 1).reshape((c, -1))
        input_prob = np.exp(x - x.max(axis=0).reshape((1, -1)))
        input_prob /= input_prob.sum(axis=0).reshape((1, -1))

        valid_flag = input_label != self.ignore_label
        valid_inds = np.where(valid_flag)[0]
        label = input_label[valid_flag]
        num_valid = valid_flag.sum()
        if self.min_kept >= num_valid:
            print('Labels: {}'.format(num_valid))
        elif num_valid > 0:
            prob = input_prob[:, valid_flag]
            # Probability assigned to the true class of each valid pixel.
            pred = prob[label, np.arange(len(label), dtype=np.int32)]
            threshold = self.thresh
            if self.min_kept > 0:
                # Raise the threshold if needed so at least min_kept pixels
                # survive the hard-example filter.
                index = pred.argsort()
                threshold_index = index[min(len(index), self.min_kept) - 1]
                if pred[threshold_index] > self.thresh:
                    threshold = pred[threshold_index]
            kept_flag = pred <= threshold
            valid_inds = valid_inds[kept_flag]

        # Relabel: every pixel not kept becomes ignore_label.
        label = input_label[valid_inds].copy()
        input_label.fill(self.ignore_label)
        input_label[valid_inds] = label
        valid_flag_new = input_label != self.ignore_label
        # print(np.sum(valid_flag_new))
        # NOTE(review): assumes CUDA is available — confirm before CPU-only use.
        target = Variable(torch.from_numpy(input_label.reshape(target.size())).long().cuda())

        return self.criterion(predict, target)
class OHEMSegmentationLosses(OhemCrossEntropy2d):
    """2D Cross Entropy Loss with Auxiliary Loss, on top of the OHEM base loss.

    Same (se_loss, aux) dispatch as ``SegmentationLosses``, but every cross
    entropy term is the OHEM-filtered loss from ``OhemCrossEntropy2d``.
    """
    def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
                 aux=False, aux_weight=0.4, weight=None,
                 ignore_index=-1):
        # ignore_index is forwarded positionally as the base's ignore_label.
        super(OHEMSegmentationLosses, self).__init__(ignore_index)
        self.se_loss = se_loss        # enable semantic-encoding (class-presence) term
        self.aux = aux                # enable auxiliary-head term
        self.nclass = nclass          # number of classes for the presence vector
        self.se_weight = se_weight    # weight of the se term
        self.aux_weight = aux_weight  # weight of the aux term
        self.bceloss = nn.BCELoss(weight)

    def forward(self, *inputs):
        if not self.se_loss and not self.aux:
            return super(OHEMSegmentationLosses, self).forward(*inputs)
        elif not self.se_loss:
            pred1, pred2, target = tuple(inputs)
            loss1 = super(OHEMSegmentationLosses, self).forward(pred1, target)
            loss2 = super(OHEMSegmentationLosses, self).forward(pred2, target)
            return loss1 + self.aux_weight * loss2
        elif not self.aux:
            pred, se_pred, target = tuple(inputs)
            # Binary class-presence targets derived from the segmentation mask.
            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
            loss1 = super(OHEMSegmentationLosses, self).forward(pred, target)
            loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
            return loss1 + self.se_weight * loss2
        else:
            pred1, se_pred, pred2, target = tuple(inputs)
            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
            loss1 = super(OHEMSegmentationLosses, self).forward(pred1, target)
            loss2 = super(OHEMSegmentationLosses, self).forward(pred2, target)
            loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
            return loss1 + self.aux_weight * loss2 + self.se_weight * loss3

    @staticmethod
    def _get_batch_label_vector(target, nclass):
        # target is a 3D Variable BxHxW, output is 2D BxnClass:
        # tvect[b, c] == 1 iff class c appears anywhere in image b.
        batch = target.size(0)
        tvect = Variable(torch.zeros(batch, nclass))
        for i in range(batch):
            # Histogram over class ids present in this image.
            hist = torch.histc(target[i].cpu().data.float(),
                               bins=nclass, min=0,
                               max=nclass-1)
            vect = hist > 0
            tvect[i] = vect
        return tvect
......@@ -9,117 +9,23 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Synchronized Cross-GPU Batch Normalization Module"""
import collections
import threading
import warnings
try:
from queue import Queue
except ImportError:
from Queue import Queue
import torch
from torch.nn import Module, Sequential, Conv1d, Conv2d, ConvTranspose2d, \
ReLU, Sigmoid, MaxPool2d, AvgPool2d, AdaptiveAvgPool2d, Dropout2d, Linear, \
DataParallel
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.functional import batch_norm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
from ..utils.misc import EncodingDeprecationWarning
from ..functions import *
from ..parallel import allreduce
from .comm import SyncMaster
__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'Module', 'Sequential', 'Conv1d',
'Conv2d', 'ConvTranspose2d', 'ReLU', 'Sigmoid', 'MaxPool2d', 'AvgPool2d',
'AdaptiveAvgPool2d', 'Dropout2d', 'Linear']
class _SyncBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
super(_SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
self._sync_master = SyncMaster(self._data_parallel_master)
self._parallel_id = None
self._slave_pipe = None
def forward(self, input):
if not self.training:
return batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
# Resize the input to (B, C, -1).
input_shape = input.size()
input = input.view(input_shape[0], self.num_features, -1)
# sum(x) and sum(x^2)
N = input.size(0) * input.size(2)
xsum, xsqsum = sum_square(input)
# all-reduce for global sum(x) and sum(x^2)
if self._parallel_id == 0:
mean, inv_std = self._sync_master.run_master(_ChildMessage(xsum, xsqsum, N))
else:
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(xsum, xsqsum, N))
# forward
return batchnormtrain(input, mean, 1.0/inv_std, self.weight, self.bias).view(input_shape)
def __data_parallel_replicate__(self, ctx, copy_id):
self._parallel_id = copy_id
# parallel_id == 0 means master device.
if self._parallel_id == 0:
ctx.sync_master = self._sync_master
else:
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
def _data_parallel_master(self, intermediates):
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
# Always using same "device order" makes the ReduceAdd operation faster.
# Thanks to:: Tete Xiao (http://tetexiao.com/)
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
to_reduce = [i[1][:2] for i in intermediates]
to_reduce = [j for i in to_reduce for j in i] # flatten
target_gpus = [i[1].sum.get_device() for i in intermediates]
sum_size = sum([i[1].sum_size for i in intermediates])
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
outputs = []
for i, rec in enumerate(intermediates):
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
__all__ = ['SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']
return outputs
def _compute_mean_std(self, sum_, ssum, size):
"""Compute the mean and standard-deviation with sum and square-sum. This method
also maintains the moving average on the master device."""
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
mean = sum_ / size
sumvar = ssum - sum_ * mean
unbias_var = sumvar / (size - 1)
bias_var = sumvar / size
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
return mean, (bias_var + self.eps) ** -0.5
# API adapted from https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
_ChildMessage = collections.namedtuple('Message', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
class BatchNorm1d(_SyncBatchNorm):
    r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""

    def _check_input_dim(self, input):
        """Validate that *input* is 2D (N, C) or 3D (N, C, L).

        Raises:
            ValueError: if the input has any other dimensionality.
        """
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        # Bug fix: the original delegated via super(BatchNorm2d, self),
        # which raises TypeError for a BatchNorm1d instance; delegate
        # through this class instead.
        super(BatchNorm1d, self)._check_input_dim(input)
class BatchNorm2d(_SyncBatchNorm):
class SyncBatchNorm(_BatchNorm):
r"""Cross-GPU Synchronized Batch normalization (SyncBN)
Standard BN [1]_ implementation only normalize the data within each device (GPU).
......@@ -127,11 +33,6 @@ class BatchNorm2d(_SyncBatchNorm):
We follow the sync-once implementation described in the paper [2]_ .
Please see the design idea in the `notes <./notes/syncbn.html>`_.
.. note::
We adapt the awesome python API from another `PyTorch SyncBN Implementation
<https://github.com/vacancy/Synchronized-BatchNorm-PyTorch>`_ and provide
efficient CUDA backend.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
......@@ -155,8 +56,12 @@ class BatchNorm2d(_SyncBatchNorm):
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
sync: a boolean value that when set to ``True``, synchronize across
different gpus. Default: ``True``
activation : str
Name of the activation functions, one of: `leaky_relu` or `none`.
slope : float
Negative slope for the `leaky_relu` activation.
Shape:
- Input: :math:`(N, C, H, W)`
......@@ -167,79 +72,89 @@ class BatchNorm2d(_SyncBatchNorm):
.. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
Examples:
>>> m = BatchNorm2d(100)
>>> m = SyncBatchNorm(100)
>>> net = torch.nn.DataParallel(m)
>>> encoding.parallel.patch_replication_callback(net)
>>> output = net(input)
"""
def _check_input_dim(self, input):
    """Reject inputs that are not 4D (N, C, H, W)."""
    # NOTE(review): the super() call names BatchNorm2d, so this appears to be
    # a BatchNorm2d method whose surrounding class is cut off in this view.
    if input.dim() != 4:
        raise ValueError('expected 4D input (got {}D input)'
                         .format(input.dim()))
    super(BatchNorm2d, self)._check_input_dim(input)
class BatchNorm3d(_SyncBatchNorm):
    r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""

    def _check_input_dim(self, input):
        # Volumetric inputs must be 5D: (N, C, D, H, W).
        ndim = input.dim()
        if ndim != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(ndim))
        super(BatchNorm3d, self)._check_input_dim(input)
class SharedTensor(object):
    """Shared Tensor for cross GPU all reduce operation.

    Rendezvous object shared by ``nGPUs`` device threads: every thread
    ``push``es its partial statistics, the thread for GPU 0 performs one
    fused allreduce on behalf of all of them, and each thread then ``pull``s
    its reduced slice back. Two counters (push_tasks / reduce_tasks) act as
    two successive barriers on the shared condition variable.
    """
    def __init__(self, nGPUs):
        self.mutex = threading.Lock()
        # Condition shares the mutex so barrier waits release the same lock.
        self.all_tasks_done = threading.Condition(self.mutex)
        self.nGPUs = nGPUs
        self._clear()

    def _clear(self):
        # Reset per-round state; both counters start at the number of GPUs.
        self.N = 0
        self.dict = {}
        self.push_tasks = self.nGPUs
        self.reduce_tasks = self.nGPUs

    def push(self, *inputs):
        # push from device
        # inputs layout: (count, gpu_index, tensors...) — count is summed
        # across devices, the tensors are stored per GPU for the allreduce.
        with self.mutex:
            if self.push_tasks == 0:
                # First pusher of a new round clears the previous round.
                self._clear()
            self.N += inputs[0]
            igpu = inputs[1]
            self.dict[igpu] = inputs[2:]
            self.push_tasks -= 1
        # Barrier: block until every device has pushed its statistics.
        with self.all_tasks_done:
            if self.push_tasks == 0:
                self.all_tasks_done.notify_all()
            while self.push_tasks:
                self.all_tasks_done.wait()

    def pull(self, igpu):
        # pull from device
        with self.mutex:
            if igpu == 0:
                assert(len(self.dict) == self.nGPUs)
                # flatten the tensors
                self.list = [t for i in range(len(self.dict)) for t in self.dict[i]]
                # GPU 0 runs the single fused allreduce for everyone.
                self.outlist = allreduce(2, *self.list)
                self.reduce_tasks -= 1
            else:
                self.reduce_tasks -= 1
        # Barrier: block until the reduction result is available.
        with self.all_tasks_done:
            if self.reduce_tasks == 0:
                self.all_tasks_done.notify_all()
            while self.reduce_tasks:
                self.all_tasks_done.wait()
        # all reduce done
        # Each GPU owns two consecutive entries of the flattened output list.
        return self.N, self.outlist[2*igpu], self.outlist[2*igpu+1]

    def __len__(self):
        return self.nGPUs

    def __repr__(self):
        return ('SharedTensor')
def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation="none", slope=0.01,
             inplace=True):
    """Cross-GPU synchronized batch norm layer.

    Args:
        num_features (int): number of channels C.
        eps (float): numerical-stability term added to the variance.
        momentum (float): running-stats update factor.
        sync (bool): synchronize statistics across GPUs; forced off when
            only one device is visible.
        activation (str): fused activation, 'none' or 'leaky_relu'.
        slope (float): negative slope for 'leaky_relu'.
        inplace (bool): use the in-place fused kernel; ignored (forced False)
            when activation == 'none'.
    """
    super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True)
    self.activation = activation
    # The in-place kernel is only used together with a fused activation.
    self.inplace = False if activation == 'none' else inplace
    self.slope = slope
    self.devices = list(range(torch.cuda.device_count()))
    # No point synchronizing when there is a single device.
    self.sync = sync if len(self.devices) > 1 else False
    # Initialize queues
    # Device 0 is the master; every other device gets a 1-slot reply queue.
    self.worker_ids = self.devices[1:]
    self.master_queue = Queue(len(self.worker_ids))
    self.worker_queues = [Queue(1) for _ in self.worker_ids]
def forward(self, x):
    """Apply (optionally cross-GPU synchronized) batch norm to ``x``.

    The input is flattened to (B, C, -1) so the same kernel handles any
    spatial rank; the original shape is restored before returning.
    """
    # Resize the input to (B, C, -1).
    input_shape = x.size()
    x = x.view(input_shape[0], self.num_features, -1)
    if x.get_device() == self.devices[0]:
        # Master mode
        # The master owns the shared inbox plus one reply queue per worker.
        extra = {
            "is_master": True,
            "master_queue": self.master_queue,
            "worker_queues": self.worker_queues,
            "worker_ids": self.worker_ids
        }
    else:
        # Worker mode
        # A worker only needs the master inbox and its own reply queue.
        extra = {
            "is_master": False,
            "master_queue": self.master_queue,
            "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
        }
    # Dispatch to the fused in-place kernel or the regular one; both restore
    # the original input shape on return.
    if self.inplace:
        return inp_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                                 extra, self.sync, self.training, self.momentum, self.eps,
                                 self.activation, self.slope).view(input_shape)
    else:
        return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                             extra, self.sync, self.training, self.momentum, self.eps,
                             self.activation, self.slope).view(input_shape)
def extra_repr(self):
    """Summarize configuration for the module's repr string."""
    # Without a fused activation, only the sync flag is interesting.
    if self.activation == 'none':
        return 'sync={}'.format(self.sync)
    return 'sync={}, act={}, slope={}, inplace={}'.format(
        self.sync, self.activation, self.slope, self.inplace
    )
class BatchNorm1d(SyncBatchNorm):
    r"""
    .. warning::
        BatchNorm1d is deprecated in favor of :class:`encoding.nn.SyncBatchNorm`.
    """
    def __init__(self, *args, **kwargs):
        # Emit a deprecation notice, then behave exactly like SyncBatchNorm.
        message = "encoding.nn.{} is now deprecated in favor of encoding.nn.{}.".format(
            'BatchNorm1d', SyncBatchNorm.__name__)
        warnings.warn(message, EncodingDeprecationWarning)
        super(BatchNorm1d, self).__init__(*args, **kwargs)
class BatchNorm2d(SyncBatchNorm):
    r"""
    .. warning::
        BatchNorm2d is deprecated in favor of :class:`encoding.nn.SyncBatchNorm`.
    """
    def __init__(self, *args, **kwargs):
        # Emit a deprecation notice, then behave exactly like SyncBatchNorm.
        message = "encoding.nn.{} is now deprecated in favor of encoding.nn.{}.".format(
            'BatchNorm2d', SyncBatchNorm.__name__)
        warnings.warn(message, EncodingDeprecationWarning)
        super(BatchNorm2d, self).__init__(*args, **kwargs)
class BatchNorm3d(SyncBatchNorm):
    r"""
    .. warning::
        BatchNorm3d is deprecated in favor of :class:`encoding.nn.SyncBatchNorm`.
    """
    def __init__(self, *args, **kwargs):
        # Emit a deprecation notice, then behave exactly like SyncBatchNorm.
        message = "encoding.nn.{} is now deprecated in favor of encoding.nn.{}.".format(
            'BatchNorm3d', SyncBatchNorm.__name__)
        warnings.warn(message, EncodingDeprecationWarning)
        super(BatchNorm3d, self).__init__(*args, **kwargs)
......@@ -51,7 +51,6 @@ class AllReduce(Function):
outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
class Reduce(Function):
@staticmethod
def forward(ctx, *inputs):
......@@ -98,7 +97,6 @@ class DataParallelModel(DataParallel):
def replicate(self, module, device_ids):
    """Replicate `module` onto `device_ids`, then run the replication
    callbacks (`__data_parallel_replicate__`) on every replica so replicas
    can wire up shared per-submodule state."""
    modules = super(DataParallelModel, self).replicate(module, device_ids)
    execute_replication_callbacks(modules)
    return modules
......@@ -133,7 +131,6 @@ class DataParallelCriterion(DataParallel):
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
return Reduce.apply(*outputs) / len(outputs)
#return self.gather(outputs, self.output_device).mean()
def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
......@@ -188,62 +185,3 @@ def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices
raise output
outputs.append(output)
return outputs
###########################################################################
# Adapted from Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
#
class CallbackContext(object):
    """Empty attribute bag shared between corresponding sub-modules of all
    replicas of the same network."""
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback ``__data_parallel_replicate__`` on each
    module created by the original replication.

    The callback is invoked as ``__data_parallel_replicate__(ctx, copy_id)``.
    Since all replicas are isomorphic, each sub-module position is assigned
    one :class:`CallbackContext` shared across the copies on every device, so
    the copies can exchange information through it. The callback on the
    master copy (``modules[0]``) is guaranteed to run before the callback on
    any slave copy.
    """
    primary = modules[0]
    num_submodules = len(list(primary.modules()))
    shared_ctxs = [CallbackContext() for _ in range(num_submodules)]
    for copy_id, replica in enumerate(modules):
        for slot, submodule in enumerate(replica.modules()):
            if hasattr(submodule, '__data_parallel_replicate__'):
                submodule.__data_parallel_replicate__(shared_ctxs[slot], copy_id)
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object so that every replication
    also runs the replication callbacks. Useful when you have a customized
    `DataParallel` implementation.
    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel)

    original_replicate = data_parallel.replicate

    @functools.wraps(original_replicate)
    def replicate_with_callbacks(module, device_ids):
        # Delegate to the original replicate, then fire the callbacks on
        # every replica it produced.
        replicas = original_replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    data_parallel.replicate = replicate_with_callbacks
import torch
from torchvision.transforms import *
def get_transform(dataset, large_test_crop=False):
    """Return ``(transform_train, transform_val)`` for the named dataset.

    Args:
        dataset (str): one of 'imagenet', 'minc', 'cifar10'.
        large_test_crop (bool): for 'imagenet' only — validate with a
            366/320 resize/crop instead of the standard 256/224.

    Raises:
        ValueError: if `dataset` is not one of the supported names
            (previously an unknown name fell through to the return and
            raised UnboundLocalError).
    """
    normalize = Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
    if dataset == 'imagenet':
        transform_train = Compose([
            Resize(256),
            RandomResizedCrop(224),
            RandomHorizontalFlip(),
            ColorJitter(0.4, 0.4, 0.4),
            ToTensor(),
            # AlexNet-style PCA lighting noise (see Lighting below).
            Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
            normalize,
        ])
        if large_test_crop:
            transform_val = Compose([
                Resize(366),
                CenterCrop(320),
                ToTensor(),
                normalize,
            ])
        else:
            transform_val = Compose([
                Resize(256),
                CenterCrop(224),
                ToTensor(),
                normalize,
            ])
    elif dataset == 'minc':
        transform_train = Compose([
            Resize(256),
            RandomResizedCrop(224),
            RandomHorizontalFlip(),
            ColorJitter(0.4, 0.4, 0.4),
            ToTensor(),
            Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
            normalize,
        ])
        transform_val = Compose([
            Resize(256),
            CenterCrop(224),
            ToTensor(),
            normalize,
        ])
    elif dataset == 'cifar10':
        # Bug fix: this branch referenced an undefined `transforms.` prefix;
        # only `from torchvision.transforms import *` is imported here, so
        # the star-imported names are used directly like the other branches.
        transform_train = Compose([
            RandomCrop(32, padding=4),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize((0.4914, 0.4822, 0.4465),
                      (0.2023, 0.1994, 0.2010)),
        ])
        transform_val = Compose([
            ToTensor(),
            Normalize((0.4914, 0.4822, 0.4465),
                      (0.2023, 0.1994, 0.2010)),
        ])
    else:
        raise ValueError('unknown dataset: {}'.format(dataset))
    return transform_train, transform_val
# PCA eigenvalues/eigenvectors of RGB pixel values (ImageNet statistics,
# per the name), consumed by the `Lighting` transform for AlexNet-style
# PCA color augmentation.
_imagenet_pca = {
    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.Tensor([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])
}
class Lighting(object):
    """Lighting noise (AlexNet-style PCA-based color noise)."""

    def __init__(self, alphastd, eigval, eigvec):
        # alphastd: std-dev of the random per-component coefficients.
        # eigval/eigvec: PCA eigenvalues and eigenvectors of the RGB channels.
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        # Zero magnitude means the transform is a no-op.
        if self.alphastd == 0:
            return img
        # Sample one coefficient per principal component.
        coeffs = img.new().resize_(3).normal_(0, self.alphastd)
        # Per-channel shift: sum of eigenvectors weighted by coefficient * eigenvalue.
        shift = (self.eigvec.type_as(img).clone()
                 * coeffs.view(1, 3).expand(3, 3)
                 * self.eigval.view(1, 3).expand(3, 3)).sum(1).squeeze()
        return img.add(shift.view(3, 1, 1).expand_as(img))
......@@ -12,9 +12,10 @@
from .lr_scheduler import LR_Scheduler
from .metrics import SegmentationMetric, batch_intersection_union, batch_pix_accuracy
from .pallete import get_mask_pallete
from .train_helper import get_selabel_vector, EMA
from .train_helper import *
from .presets import load_image
from .files import *
from .misc import *
__all__ = ['LR_Scheduler', 'batch_pix_accuracy', 'batch_intersection_union',
'save_checkpoint', 'download', 'mkdir', 'check_sha1', 'load_image',
......
import warnings

__all__ = ['EncodingDeprecationWarning']


class EncodingDeprecationWarning(DeprecationWarning):
    """Package-specific deprecation warning category for `encoding`."""
    pass

# Show each distinct deprecation message only once per run.
warnings.simplefilter('once', EncodingDeprecationWarning)
......@@ -19,7 +19,7 @@ def get_mask_pallete(npimg, dataset='detail'):
out_img = Image.fromarray(npimg.squeeze().astype('uint8'))
if dataset == 'ade20k':
out_img.putpalette(adepallete)
elif dataset == 'cityscapes':
elif dataset == 'citys':
out_img.putpalette(citypallete)
elif dataset in ('detail', 'pascal_voc', 'pascal_aug'):
out_img.putpalette(vocpallete)
......
......@@ -4,13 +4,13 @@ import numpy as np
import torch
import torchvision.transforms as transform
__all__ = ['load_image', 'subtract_imagenet_mean_batch']
__all__ = ['load_image']
input_transform = transform.Compose([
transform.ToTensor(),
transform.Normalize([.485, .456, .406], [.229, .224, .225])])
def load_image(filename, size=None, scale=None, keep_asp=True):
def load_image(filename, size=None, scale=None, keep_asp=True, transform=input_transform):
"""Load the image for demos"""
img = Image.open(filename).convert('RGB')
if size is not None:
......@@ -22,5 +22,6 @@ def load_image(filename, size=None, scale=None, keep_asp=True):
elif scale is not None:
img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS)
img = input_transform(img)
if transform:
img = transform(img)
return img
......@@ -9,6 +9,12 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
import torch.nn as nn
#from ..nn import SyncBatchNorm
from torch.nn.modules.batchnorm import _BatchNorm
__all__ = ['get_selabel_vector']
def get_selabel_vector(target, nclass):
r"""Get SE-Loss Label in a batch
......
......@@ -49,7 +49,7 @@ def make_dataset(filename, datadir, class_to_idx):
return images, labels
class MINCDataloder(data.Dataset):
class MINCDataset(data.Dataset):
def __init__(self, root, train=True, transform=None):
self.transform = transform
classes, class_to_idx = find_classes(root + '/images')
......@@ -94,9 +94,9 @@ class Dataloader():
normalize,
])
trainset = MINCDataloder(root=os.path.expanduser('~/.encoding/data/minc-2500/'),
trainset = MINCDataset(root=os.path.expanduser('~/.encoding/data/minc-2500/'),
train=True, transform=transform_train)
testset = MINCDataloder(root=os.path.expanduser('~/.encoding/data/minc-2500/'),
testset = MINCDataset(root=os.path.expanduser('~/.encoding/data/minc-2500/'),
train=False, transform=transform_test)
kwargs = {'num_workers': 8, 'pin_memory': True} if args.cuda else {}
......@@ -133,7 +133,7 @@ class Lighting(object):
if __name__ == "__main__":
trainset = MINCDataloder(root=os.path.expanduser('~/data/minc-2500/'), train=True)
testset = MINCDataloder(root=os.path.expanduser('~/data/minc-2500/'), train=False)
trainset = MINCDataset(root=os.path.expanduser('~/.encoding/data/minc-2500/'), train=True)
testset = MINCDataset(root=os.path.expanduser('~/.encoding/data/minc-2500/'), train=False)
print(len(trainset))
print(len(testset))
......@@ -9,47 +9,45 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from __future__ import print_function
import os
import matplotlib.pyplot as plot
import importlib
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import encoding
from option import Options
from encoding.utils import *
from tqdm import tqdm
# global variable
best_pred = 100.0
errlist_train = []
errlist_val = []
best_pred = 0.0
acclist_train = []
acclist_val = []
def main():
# init the args
global best_pred, errlist_train, errlist_val
global best_pred, acclist_train, acclist_val
args = Options().parse()
args.cuda = not args.no_cuda and torch.cuda.is_available()
print(args)
torch.manual_seed(args.seed)
# plot
if args.plot:
print('=>Enabling matplotlib for display:')
plot.ion()
plot.show()
if args.cuda:
torch.cuda.manual_seed(args.seed)
# init dataloader
dataset = importlib.import_module('dataset.'+args.dataset)
Dataloader = dataset.Dataloader
train_loader, test_loader = Dataloader(args).getloader()
transform_train, transform_val = encoding.transforms.get_transform(args.dataset)
trainset = encoding.datasets.get_dataset(args.dataset, root=os.path.expanduser('~/.encoding/data'),
transform=transform_train, train=True, download=True)
valset = encoding.datasets.get_dataset(args.dataset, root=os.path.expanduser('~/.encoding/data'),
transform=transform_val, train=False, download=True)
train_loader = torch.utils.data.DataLoader(
trainset, batch_size=args.batch_size, shuffle=True,
num_workers=args.workers, pin_memory=True)
val_loader = torch.utils.data.DataLoader(
valset, batch_size=args.test_batch_size, shuffle=False,
num_workers=args.workers, pin_memory=True)
# init the model
models = importlib.import_module('model.'+args.model)
model = models.Net(args)
model = encoding.models.get_model(args.model, pretrained=args.pretrained)
print(model)
# criterion and optimizer
criterion = nn.CrossEntropyLoss()
......@@ -58,8 +56,9 @@ def main():
weight_decay=args.weight_decay)
if args.cuda:
model.cuda()
criterion.cuda()
# Please use CUDA_VISIBLE_DEVICES to control the number of gpus
model = torch.nn.DataParallel(model)
model = nn.DataParallel(model)
# check point
if args.resume is not None:
if os.path.isfile(args.resume):
......@@ -67,108 +66,116 @@ def main():
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch'] +1
best_pred = checkpoint['best_pred']
errlist_train = checkpoint['errlist_train']
errlist_val = checkpoint['errlist_val']
model.load_state_dict(checkpoint['state_dict'])
acclist_train = checkpoint['acclist_train']
acclist_val = checkpoint['acclist_val']
model.module.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
raise RuntimeError ("=> no resume checkpoint found at '{}'".\
format(args.resume))
scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
scheduler = encoding.utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
len(train_loader), args.lr_step)
def train(epoch):
model.train()
global best_pred, errlist_train
train_loss, correct, total = 0,0,0
losses = AverageMeter()
top1 = AverageMeter()
global best_pred, acclist_train
tbar = tqdm(train_loader, desc='\r')
for batch_idx, (data, target) in enumerate(tbar):
scheduler(optimizer, batch_idx, epoch, best_pred)
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
train_loss += loss.data.item()
pred = output.data.max(1)[1]
correct += pred.eq(target.data).cpu().sum().item()
total += target.size(0)
err = 100.0 - 100.0 * correct / total
tbar.set_description('\rLoss: %.3f | Err: %.3f%% (%d/%d)' % \
(train_loss/(batch_idx+1), err, total-correct, total))
acc1 = accuracy(output, target, topk=(1,))
top1.update(acc1[0], data.size(0))
losses.update(loss.item(), data.size(0))
tbar.set_description('\rLoss: %.3f | Top1: %.3f'%(losses.avg, top1.avg))
errlist_train += [err]
acclist_train += [top1.avg]
def test(epoch):
def validate(epoch):
model.eval()
global best_pred, errlist_train, errlist_val
test_loss, correct, total = 0,0,0
top1 = AverageMeter()
top5 = AverageMeter()
global best_pred, acclist_train, acclist_val
is_best = False
tbar = tqdm(test_loader, desc='\r')
tbar = tqdm(val_loader, desc='\r')
for batch_idx, (data, target) in enumerate(tbar):
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
with torch.no_grad():
output = model(data)
test_loss += criterion(output, target).data.item()
# get the index of the max log-probability
pred = output.data.max(1)[1]
correct += pred.eq(target.data).cpu().sum().item()
total += target.size(0)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
top1.update(acc1[0], data.size(0))
top5.update(acc5[0], data.size(0))
err = 100.0 - 100.0 * correct / total
tbar.set_description('Loss: %.3f | Err: %.3f%% (%d/%d)'% \
(test_loss/(batch_idx+1), err, total-correct, total))
tbar.set_description('Top1: %.3f | Top5: %.3f'%(top1.avg, top5.avg))
if args.eval:
print('Error rate is %.3f'%err)
print('Top1 Acc: %.3f | Top5 Acc: %.3f '%(top1.avg, top5.avg))
return
# save checkpoint
errlist_val += [err]
if err < best_pred:
best_pred = err
acclist_val += [top1.avg]
if top1.avg > best_pred:
best_pred = top1.avg
is_best = True
save_checkpoint({
encoding.utils.save_checkpoint({
'epoch': epoch,
'state_dict': model.state_dict(),
'state_dict': model.module.state_dict(),
'optimizer': optimizer.state_dict(),
'best_pred': best_pred,
'errlist_train':errlist_train,
'errlist_val':errlist_val,
'acclist_train':acclist_train,
'acclist_val':acclist_val,
}, args=args, is_best=is_best)
if args.plot:
plot.clf()
plot.xlabel('Epoches: ')
plot.ylabel('Error Rate: %')
plot.plot(errlist_train, label='train')
plot.plot(errlist_val, label='val')
plot.legend(loc='upper left')
plot.draw()
plot.pause(0.001)
if args.eval:
test(args.start_epoch)
validate(args.start_epoch)
return
for epoch in range(args.start_epoch, args.epochs + 1):
train(epoch)
test(epoch)
# save train_val curve to a file
if args.plot:
plot.clf()
plot.xlabel('Epoches: ')
plot.ylabel('Error Rate: %')
plot.plot(errlist_train, label='train')
plot.plot(errlist_val, label='val')
plot.savefig("runs/%s/%s/"%(args.dataset, args.checkname)
+'train_val.jpg')
validate(epoch)
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: (batch, num_classes) prediction scores.
        target: (batch,) ground-truth class indices.
        topk: iterable of k values to report.

    Returns:
        List of 1-element tensors, the top-k accuracy in percent for each k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # Bug fix: `correct` is non-contiguous after the transpose, so
            # view(-1) raises a RuntimeError for k > 1; reshape(-1) handles
            # both contiguous and non-contiguous layouts.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
class AverageMeter(object):
    """Tracks the most recent value and a running (sample-weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear all statistics back to the empty state.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # `n` is how many samples `val` was averaged over (e.g. batch size).
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
# Script entry point.
if __name__ == "__main__":
    main()
......@@ -14,7 +14,7 @@ import torch.nn as nn
from torch.autograd import Variable
import encoding
import encoding.dilated.resnet as resnet
import encoding.models.resnet as resnet
class Net(nn.Module):
def __init__(self, args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment