Unverified commit b872eb8c authored by Hang Zhang, committed by GitHub

ResNeSt plus (#256)

parent 5a1e3fbc
@@ -12,4 +12,8 @@
from .encoding import *
from .syncbn import *
from .customize import *
+from .attention import *
from .loss import *
+from .rectify import *
+from .splat import SplAtConv2d
+from .dropblock import *
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2018
###########################################################################
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from .syncbn import SyncBatchNorm
__all__ = ['ACFModule', 'MixtureOfSoftMaxACF']
class ACFModule(nn.Module):
""" Multi-Head Attention module """
def __init__(self, n_head, n_mix, d_model, d_k, d_v, norm_layer=SyncBatchNorm,
kq_transform='conv', value_transform='conv',
pooling=True, concat=False, dropout=0.1):
super(ACFModule, self).__init__()
self.n_head = n_head
self.n_mix = n_mix
self.d_k = d_k
self.d_v = d_v
self.pooling = pooling
self.concat = concat
if self.pooling:
self.pool = nn.AvgPool2d(3, 2, 1, count_include_pad=False)
if kq_transform == 'conv':
self.conv_qs = nn.Conv2d(d_model, n_head*d_k, 1)
nn.init.normal_(self.conv_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
elif kq_transform == 'ffn':
self.conv_qs = nn.Sequential(
nn.Conv2d(d_model, n_head*d_k, 3, padding=1, bias=False),
norm_layer(n_head*d_k),
nn.ReLU(True),
nn.Conv2d(n_head*d_k, n_head*d_k, 1),
)
nn.init.normal_(self.conv_qs[-1].weight, mean=0, std=np.sqrt(1.0 / d_k))
elif kq_transform == 'dffn':
self.conv_qs = nn.Sequential(
nn.Conv2d(d_model, n_head*d_k, 3, padding=4, dilation=4, bias=False),
norm_layer(n_head*d_k),
nn.ReLU(True),
nn.Conv2d(n_head*d_k, n_head*d_k, 1),
)
nn.init.normal_(self.conv_qs[-1].weight, mean=0, std=np.sqrt(1.0 / d_k))
else:
            raise NotImplementedError("kq_transform must be one of 'conv', 'ffn' or 'dffn'")
#self.conv_ks = nn.Conv2d(d_model, n_head*d_k, 1)
self.conv_ks = self.conv_qs
if value_transform == 'conv':
self.conv_vs = nn.Conv2d(d_model, n_head*d_v, 1)
else:
            raise NotImplementedError("value_transform must be 'conv'")
#nn.init.normal_(self.conv_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.conv_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
self.attention = MixtureOfSoftMaxACF(n_mix=n_mix, d_k=d_k)
self.conv = nn.Conv2d(n_head*d_v, d_model, 1, bias=False)
self.norm_layer = norm_layer(d_model)
def forward(self, x):
residual = x
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
b_, c_, h_, w_ = x.size()
if self.pooling:
qt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
kt = self.conv_ks(self.pool(x)).view(b_*n_head, d_k, h_*w_//4)
vt = self.conv_vs(self.pool(x)).view(b_*n_head, d_v, h_*w_//4)
else:
kt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
qt = kt
vt = self.conv_vs(x).view(b_*n_head, d_v, h_*w_)
output, attn = self.attention(qt, kt, vt)
output = output.transpose(1, 2).contiguous().view(b_, n_head*d_v, h_, w_)
output = self.conv(output)
if self.concat:
output = torch.cat((self.norm_layer(output), residual), 1)
else:
output = self.norm_layer(output) + residual
return output
def demo(self, x):
residual = x
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
b_, c_, h_, w_ = x.size()
if self.pooling:
qt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
kt = self.conv_ks(self.pool(x)).view(b_*n_head, d_k, h_*w_//4)
vt = self.conv_vs(self.pool(x)).view(b_*n_head, d_v, h_*w_//4)
else:
kt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
qt = kt
vt = self.conv_vs(x).view(b_*n_head, d_v, h_*w_)
_, attn = self.attention(qt, kt, vt)
        attn = attn.view(b_, n_head, h_*w_, -1)
return attn
def extra_repr(self):
return 'n_head={}, n_mix={}, d_k={}, pooling={}' \
.format(self.n_head, self.n_mix, self.d_k, self.pooling)
class MixtureOfSoftMaxACF(nn.Module):
""""Mixture of SoftMax"""
def __init__(self, n_mix, d_k, attn_dropout=0.1):
super(MixtureOfSoftMaxACF, self).__init__()
self.temperature = np.power(d_k, 0.5)
self.n_mix = n_mix
self.att_drop = attn_dropout
self.dropout = nn.Dropout(attn_dropout)
self.softmax1 = nn.Softmax(dim=1)
self.softmax2 = nn.Softmax(dim=2)
self.d_k = d_k
if n_mix > 1:
self.weight = nn.Parameter(torch.Tensor(n_mix, d_k))
std = np.power(n_mix, -0.5)
self.weight.data.uniform_(-std, std)
def forward(self, qt, kt, vt):
B, d_k, N = qt.size()
m = self.n_mix
assert d_k == self.d_k
d = d_k // m
if m > 1:
            # \bar{q} \in R^{B, d_k, 1}
bar_qt = torch.mean(qt, 2, True)
# pi \in R^{B, m, 1}
pi = self.softmax1(torch.matmul(self.weight, bar_qt)).view(B*m, 1, 1)
# reshape for n_mix
q = qt.view(B*m, d, N).transpose(1, 2)
N2 = kt.size(2)
kt = kt.view(B*m, d, N2)
v = vt.transpose(1, 2)
# {Bm, N, N}
attn = torch.bmm(q, kt)
attn = attn / self.temperature
attn = self.softmax2(attn)
attn = self.dropout(attn)
if m > 1:
# attn \in R^{Bm, N, N2} => R^{B, N, N2}
attn = (attn * pi).view(B, m, N, N2).sum(1)
output = torch.bmm(attn, v)
return output, attn
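
For orientation, here is a minimal smoke test of the new attention module. This is a hypothetical sketch rather than part of the commit: the shapes are illustrative, the package is assumed importable as encoding, and BatchNorm2d stands in for the SyncBatchNorm default so it runs on a single device.

# Hypothetical usage sketch (not part of this commit)
import torch
from encoding.nn import ACFModule

m = ACFModule(n_head=4, n_mix=2, d_model=64, d_k=16, d_v=16,
              norm_layer=torch.nn.BatchNorm2d)  # d_k must be divisible by n_mix
x = torch.randn(2, 64, 16, 16)  # pooling halves key/value resolution, so H and W should be even
out = m(x)                      # residual form keeps the (B, d_model, H, W) shape
assert out.shape == x.shape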
@@ -28,8 +28,6 @@ class GlobalAvgPool2d(nn.Module):
    def forward(self, inputs):
        return F.adaptive_avg_pool2d(inputs, 1).view(inputs.size(0), -1)

class GramMatrix(nn.Module):
    r""" Gram Matrix for a 4D convolutional featuremaps as a mini-batch
...
# https://github.com/Randl/MobileNetV3-pytorch/blob/master/dropblock.py
import torch
import torch.nn.functional as F
from torch import nn
__all__ = ['DropBlock2D', 'reset_dropblock']
class DropBlock2D(nn.Module):
r"""Randomly zeroes 2D spatial blocks of the input tensor.
As described in the paper
`DropBlock: A regularization method for convolutional networks`_ ,
    dropping whole blocks of a feature map removes entire regions of
    semantic information, in contrast to regular dropout.
Args:
drop_prob (float): probability of an element to be dropped.
block_size (int): size of the block to drop
Shape:
- Input: `(N, C, H, W)`
- Output: `(N, C, H, W)`
.. _DropBlock: A regularization method for convolutional networks:
https://arxiv.org/abs/1810.12890
"""
def __init__(self, drop_prob, block_size, share_channel=False):
super(DropBlock2D, self).__init__()
self.register_buffer('i', torch.zeros(1, dtype=torch.int64))
self.register_buffer('drop_prob', drop_prob * torch.ones(1, dtype=torch.float32))
self.inited = False
self.step_size = 0.0
self.start_step = 0
self.nr_steps = 0
self.block_size = block_size
self.share_channel = share_channel
def reset(self):
"""stop DropBlock"""
self.inited = True
self.i[0] = 0
        self.drop_prob[0] = 0.0  # zero the registered buffer in place (assigning a float would clobber it)
def reset_steps(self, start_step, nr_steps, start_value=0, stop_value=None):
self.inited = True
stop_value = self.drop_prob.item() if stop_value is None else stop_value
self.i[0] = 0
self.drop_prob[0] = start_value
self.step_size = (stop_value - start_value) / nr_steps
self.nr_steps = nr_steps
self.start_step = start_step
    def forward(self, x):
        if not self.training:
            return x
        # keep stepping the schedule even while drop_prob is still 0, so the
        # linear ramp set up by reset_steps() can actually take off
        self.step()
        if self.drop_prob.item() == 0.:
            return x
        # get gamma value
        gamma = self._compute_gamma(x)
        # sample mask (as float, so it can be max-pooled) and place on input device
        if self.share_channel:
            # one mask per sample, broadcast across channels
            mask = (torch.rand(x.shape[0], *x.shape[2:], device=x.device, dtype=x.dtype) < gamma).unsqueeze(1).to(x.dtype)
        else:
            mask = (torch.rand(*x.shape, device=x.device, dtype=x.dtype) < gamma).to(x.dtype)
        # compute block mask
        block_mask, keeped = self._compute_block_mask(mask)
        # apply block mask
        out = x * block_mask
        # scale output to keep the expected activation magnitude
        out = out * (block_mask.numel() / keeped).to(out)
        return out
def _compute_block_mask(self, mask):
block_mask = F.max_pool2d(mask,
kernel_size=(self.block_size, self.block_size),
stride=(1, 1),
padding=self.block_size // 2)
keeped = block_mask.numel() - block_mask.sum().to(torch.float32)
block_mask = 1 - block_mask
return block_mask, keeped
def _compute_gamma(self, x):
_, c, h, w = x.size()
gamma = self.drop_prob.item() / (self.block_size ** 2) * (h * w) / \
((w - self.block_size + 1) * (h - self.block_size + 1))
return gamma
def step(self):
assert self.inited
idx = self.i.item()
if idx > self.start_step and idx < self.start_step + self.nr_steps:
self.drop_prob += self.step_size
self.i += 1
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
idx_key = prefix + 'i'
drop_prob_key = prefix + 'drop_prob'
if idx_key not in state_dict:
state_dict[idx_key] = torch.zeros(1, dtype=torch.int64)
        if drop_prob_key not in state_dict:
            state_dict[drop_prob_key] = torch.ones(1, dtype=torch.float32)
super(DropBlock2D, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
"""overwrite save method"""
pass
def extra_repr(self):
return 'drop_prob={}, step_size={}'.format(self.drop_prob, self.step_size)
def reset_dropblock(start_step, nr_steps, start_value, stop_value, m):
"""
Example:
from functools import partial
apply_drop_prob = partial(reset_dropblock, 0, epochs*iters_per_epoch, 0.0, 0.1)
net.apply(apply_drop_prob)
"""
if isinstance(m, DropBlock2D):
        print('resetting dropblock')
m.reset_steps(start_step, nr_steps, start_value, stop_value)
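
To make the intended wiring concrete, a hypothetical training-loop sketch (net is a placeholder): reset_dropblock is registered once through Module.apply, after which drop_prob ramps linearly as step() fires on every training forward.

# Hypothetical usage sketch (not part of this commit)
import torch
from functools import partial
from encoding.nn import DropBlock2D, reset_dropblock

net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1),
                          DropBlock2D(drop_prob=0.1, block_size=3))
net.apply(partial(reset_dropblock, 0, 100, 0.0, 0.1))  # ramp drop_prob 0.0 -> 0.1 over 100 steps
net.train()
for _ in range(3):
    out = net(torch.randn(4, 3, 16, 16))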
@@ -17,7 +17,8 @@ from torch.nn.modules.utils import _pair
from ..functions import scaled_l2, aggregate, pairwise_cosine

-__all__ = ['Encoding', 'EncodingDrop', 'Inspiration', 'UpsampleConv2d']
+__all__ = ['Encoding', 'EncodingDrop', 'Inspiration', 'UpsampleConv2d',
+           'EncodingCosine']

class Encoding(Module):
    r"""
@@ -304,3 +305,43 @@ class UpsampleConv2d(Module):
        out = F.conv2d(input, self.weight, self.bias, self.stride,
                       self.padding, self.dilation, self.groups)
        return F.pixel_shuffle(out, self.scale_factor)
# Experimental
class EncodingCosine(Module):
def __init__(self, D, K):
super(EncodingCosine, self).__init__()
# init codewords and smoothing factor
self.D, self.K = D, K
self.codewords = Parameter(torch.Tensor(K, D), requires_grad=True)
#self.scale = Parameter(torch.Tensor(K), requires_grad=True)
self.reset_params()
def reset_params(self):
std1 = 1./((self.K*self.D)**(1/2))
self.codewords.data.uniform_(-std1, std1)
#self.scale.data.uniform_(-1, 0)
def forward(self, X):
# input X is a 4D tensor
assert(X.size(1) == self.D)
if X.dim() == 3:
# BxDxN
B, D = X.size(0), self.D
X = X.transpose(1, 2).contiguous()
elif X.dim() == 4:
# BxDxHxW
B, D = X.size(0), self.D
X = X.view(B, D, -1).transpose(1, 2).contiguous()
else:
raise RuntimeError('Encoding Layer unknown input dims!')
        # assignment weights BxNxK
L = pairwise_cosine(X, self.codewords)
A = F.softmax(L, dim=2)
# aggregate
E = aggregate(A, X, self.codewords)
return E
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'N x ' + str(self.D) + '=>' + str(self.K) + 'x' \
+ str(self.D) + ')'
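
A shape sketch for the experimental layer; hypothetical, and it assumes a built install, since pairwise_cosine and aggregate are the package's compiled ops.

# Hypothetical usage sketch (not part of this commit)
import torch
from encoding.nn import EncodingCosine

layer = EncodingCosine(D=64, K=8)   # 8 codewords of dimension 64
x = torch.randn(2, 64, 14, 14)      # (B, D, H, W)
e = layer(x)                        # aggregated residuals: (B, K, D) == (2, 8, 64)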
@@ -2,128 +2,60 @@ import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
-import numpy as np

-__all__ = ['SegmentationLosses', 'OhemCrossEntropy2d', 'OHEMSegmentationLosses']
+__all__ = ['LabelSmoothing', 'NLLMultiLabelSmooth', 'SegmentationLosses']

-class SegmentationLosses(nn.CrossEntropyLoss):
-    """2D Cross Entropy Loss with Auxilary Loss"""
-    def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
-                 aux=False, aux_weight=0.4, weight=None,
-                 ignore_index=-1):
-        super(SegmentationLosses, self).__init__(weight, None, ignore_index)
-        self.se_loss = se_loss
-        self.aux = aux
-        self.nclass = nclass
-        self.se_weight = se_weight
-        self.aux_weight = aux_weight
-        self.bceloss = nn.BCELoss(weight)
-
-    def forward(self, *inputs):
-        if not self.se_loss and not self.aux:
-            return super(SegmentationLosses, self).forward(*inputs)
-        elif not self.se_loss:
-            pred1, pred2, target = tuple(inputs)
-            loss1 = super(SegmentationLosses, self).forward(pred1, target)
-            loss2 = super(SegmentationLosses, self).forward(pred2, target)
-            return loss1 + self.aux_weight * loss2
-        elif not self.aux:
-            pred, se_pred, target = tuple(inputs)
-            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
-            loss1 = super(SegmentationLosses, self).forward(pred, target)
-            loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
-            return loss1 + self.se_weight * loss2
-        else:
-            pred1, se_pred, pred2, target = tuple(inputs)
-            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
-            loss1 = super(SegmentationLosses, self).forward(pred1, target)
-            loss2 = super(SegmentationLosses, self).forward(pred2, target)
-            loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
-            return loss1 + self.aux_weight * loss2 + self.se_weight * loss3
-
-    @staticmethod
-    def _get_batch_label_vector(target, nclass):
-        # target is a 3D Variable BxHxW, output is 2D BxnClass
-        batch = target.size(0)
-        tvect = Variable(torch.zeros(batch, nclass))
-        for i in range(batch):
-            hist = torch.histc(target[i].cpu().data.float(),
-                               bins=nclass, min=0,
-                               max=nclass-1)
-            vect = hist>0
-            tvect[i] = vect
-        return tvect
-
-# adapted from https://github.com/PkuRainBow/OCNet/blob/master/utils/loss.py
-class OhemCrossEntropy2d(nn.Module):
-    def __init__(self, ignore_label=-1, thresh=0.7, min_kept=100000, use_weight=True):
-        super(OhemCrossEntropy2d, self).__init__()
-        self.ignore_label = ignore_label
-        self.thresh = float(thresh)
-        self.min_kept = int(min_kept)
-        if use_weight:
-            print("w/ class balance")
-            weight = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
-                1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
-                1.0865, 1.1529, 1.0507])
-            self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label)
-        else:
-            print("w/o class balance")
-            self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)
-
-    def forward(self, predict, target, weight=None):
-        """
-        Args:
-            predict:(n, c, h, w)
-            target:(n, h, w)
-            weight (Tensor, optional): a manual rescaling weight given to each class.
-                If given, has to be a Tensor of size "nclasses"
-        """
-        assert not target.requires_grad
-        assert predict.dim() == 4
-        assert target.dim() == 3
-        assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
-        assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
-        assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(3))
-
-        n, c, h, w = predict.size()
-        input_label = target.data.cpu().numpy().ravel().astype(np.int32)
-        x = np.rollaxis(predict.data.cpu().numpy(), 1).reshape((c, -1))
-        input_prob = np.exp(x - x.max(axis=0).reshape((1, -1)))
-        input_prob /= input_prob.sum(axis=0).reshape((1, -1))
-
-        valid_flag = input_label != self.ignore_label
-        valid_inds = np.where(valid_flag)[0]
-        label = input_label[valid_flag]
-        num_valid = valid_flag.sum()
-        if self.min_kept >= num_valid:
-            print('Labels: {}'.format(num_valid))
-        elif num_valid > 0:
-            prob = input_prob[:,valid_flag]
-            pred = prob[label, np.arange(len(label), dtype=np.int32)]
-            threshold = self.thresh
-            if self.min_kept > 0:
-                index = pred.argsort()
-                threshold_index = index[ min(len(index), self.min_kept) - 1 ]
-                if pred[threshold_index] > self.thresh:
-                    threshold = pred[threshold_index]
-            kept_flag = pred <= threshold
-            valid_inds = valid_inds[kept_flag]
-
-        label = input_label[valid_inds].copy()
-        input_label.fill(self.ignore_label)
-        input_label[valid_inds] = label
-        valid_flag_new = input_label != self.ignore_label
-        # print(np.sum(valid_flag_new))
-        target = Variable(torch.from_numpy(input_label.reshape(target.size())).long().cuda())
-
-        return self.criterion(predict, target)
+class LabelSmoothing(nn.Module):
+    """
+    NLL loss with label smoothing.
+    """
+    def __init__(self, smoothing=0.1):
+        """
+        Constructor for the LabelSmoothing module.
+        :param smoothing: label smoothing factor
+        """
+        super(LabelSmoothing, self).__init__()
+        self.confidence = 1.0 - smoothing
+        self.smoothing = smoothing
+
+    def forward(self, x, target):
+        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
+        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
+        nll_loss = nll_loss.squeeze(1)
+        smooth_loss = -logprobs.mean(dim=-1)
+        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+        return loss.mean()
+
+class NLLMultiLabelSmooth(nn.Module):
+    def __init__(self, smoothing=0.1):
+        super(NLLMultiLabelSmooth, self).__init__()
+        self.confidence = 1.0 - smoothing
+        self.smoothing = smoothing
+
+    def forward(self, x, target):
+        if self.training:
+            x = x.float()
+            target = target.float()
+            logprobs = torch.nn.functional.log_softmax(x, dim=-1)
+            nll_loss = -logprobs * target
+            nll_loss = nll_loss.sum(-1)
+            smooth_loss = -logprobs.mean(dim=-1)
+            loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+            return loss.mean()
+        else:
+            return torch.nn.functional.cross_entropy(x, target)

-class OHEMSegmentationLosses(OhemCrossEntropy2d):
+class SegmentationLosses(nn.CrossEntropyLoss):
    """2D Cross Entropy Loss with Auxiliary Loss"""
    def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
                 aux=False, aux_weight=0.4, weight=None,
                 ignore_index=-1):
-        super(OHEMSegmentationLosses, self).__init__(ignore_index)
+        super(SegmentationLosses, self).__init__(weight, None, ignore_index)
        self.se_loss = se_loss
        self.aux = aux
        self.nclass = nclass
@@ -133,23 +65,23 @@ class OHEMSegmentationLosses(OhemCrossEntropy2d):
    def forward(self, *inputs):
        if not self.se_loss and not self.aux:
-            return super(OHEMSegmentationLosses, self).forward(*inputs)
+            return super(SegmentationLosses, self).forward(*inputs)
        elif not self.se_loss:
            pred1, pred2, target = tuple(inputs)
-            loss1 = super(OHEMSegmentationLosses, self).forward(pred1, target)
-            loss2 = super(OHEMSegmentationLosses, self).forward(pred2, target)
+            loss1 = super(SegmentationLosses, self).forward(pred1, target)
+            loss2 = super(SegmentationLosses, self).forward(pred2, target)
            return loss1 + self.aux_weight * loss2
        elif not self.aux:
            pred, se_pred, target = tuple(inputs)
            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
-            loss1 = super(OHEMSegmentationLosses, self).forward(pred, target)
+            loss1 = super(SegmentationLosses, self).forward(pred, target)
            loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
            return loss1 + self.se_weight * loss2
        else:
            pred1, se_pred, pred2, target = tuple(inputs)
            se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
-            loss1 = super(OHEMSegmentationLosses, self).forward(pred1, target)
-            loss2 = super(OHEMSegmentationLosses, self).forward(pred2, target)
+            loss1 = super(SegmentationLosses, self).forward(pred1, target)
+            loss2 = super(SegmentationLosses, self).forward(pred2, target)
            loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
            return loss1 + self.aux_weight * loss2 + self.se_weight * loss3
...
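
As a quick sanity check of the two classification losses added above (a hypothetical sketch, assuming the package imports as encoding):

# Hypothetical usage sketch (not part of this commit)
import torch
from encoding.nn import LabelSmoothing, NLLMultiLabelSmooth

logits = torch.randn(8, 1000)            # (batch, classes)
target = torch.randint(0, 1000, (8,))    # hard integer labels
loss = LabelSmoothing(smoothing=0.1)(logits, target)

# NLLMultiLabelSmooth expects soft targets (e.g. from mixup) while training
soft = torch.nn.functional.one_hot(target, 1000).float()
loss2 = NLLMultiLabelSmooth(smoothing=0.1)(logits, soft)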
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Rectify Module"""
import warnings
import torch
from torch.nn import Conv2d
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from ..functions import rectify
__all__ = ['RFConv2d']
class RFConv2d(Conv2d):
"""Rectified Convolution
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros',
average_mode=False):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
self.rectify = average_mode or (padding[0] > 0 or padding[1] > 0)
self.average = average_mode
super(RFConv2d, self).__init__(
in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups,
bias=bias, padding_mode=padding_mode)
def _conv_forward(self, input, weight):
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
weight, self.bias, self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
def forward(self, input):
output = self._conv_forward(input, self.weight)
if self.rectify:
output = rectify(output, input, self.kernel_size, self.stride,
self.padding, self.dilation, self.average)
return output
def extra_repr(self):
return super().extra_repr() + ', rectify={}, average_mode={}'. \
format(self.rectify, self.average)
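
A hypothetical call for reference; the rectified convolution re-weights border outputs to compensate for zero padding, and it relies on the package's compiled rectify op, so this assumes a built install.

# Hypothetical usage sketch (not part of this commit)
import torch
from encoding.nn import RFConv2d

conv = RFConv2d(3, 16, kernel_size=3, padding=1, average_mode=True)
y = conv(torch.randn(1, 3, 32, 32))   # border positions re-weighted; shape (1, 16, 32, 32)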
"""Split-Attention"""
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU
from torch.nn.modules.utils import _pair
from .rectify import RFConv2d  # import directly to avoid a circular import through the package __init__
from .dropblock import DropBlock2D
__all__ = ['SplAtConv2d']
class SplAtConv2d(Module):
"""Split-Attention Conv2d
"""
def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
dilation=(1, 1), groups=1, bias=True,
radix=2, reduction_factor=4,
rectify=False, rectify_avg=False, norm_layer=None,
dropblock_prob=0.0, **kwargs):
super(SplAtConv2d, self).__init__()
padding = _pair(padding)
self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
self.rectify_avg = rectify_avg
inter_channels = max(in_channels*radix//reduction_factor, 32)
self.radix = radix
self.cardinality = groups
self.channels = channels
self.dropblock_prob = dropblock_prob
if self.rectify:
self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs)
else:
self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
groups=groups*radix, bias=bias, **kwargs)
        self.use_bn = norm_layer is not None
        if self.use_bn:
            self.bn0 = norm_layer(channels*radix)
        self.relu = ReLU(inplace=True)
        self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
        if self.use_bn:
            self.bn1 = norm_layer(inter_channels)
        self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality)
if dropblock_prob > 0.0:
self.dropblock = DropBlock2D(dropblock_prob, 3)
def forward(self, x):
x = self.conv(x)
if self.use_bn:
x = self.bn0(x)
if self.dropblock_prob > 0.0:
x = self.dropblock(x)
x = self.relu(x)
batch, channel = x.shape[:2]
if self.radix > 1:
splited = torch.split(x, channel//self.radix, dim=1)
gap = sum(splited)
else:
gap = x
gap = F.adaptive_avg_pool2d(gap, 1)
gap = self.fc1(gap)
if self.use_bn:
gap = self.bn1(gap)
gap = self.relu(gap)
atten = self.fc2(gap).view((batch, self.radix, self.channels))
        if self.radix > 1:
            atten = F.softmax(atten, dim=1).view(batch, -1, 1, 1)
        else:
            atten = torch.sigmoid(atten).view(batch, -1, 1, 1)
if self.radix > 1:
atten = torch.split(atten, channel//self.radix, dim=1)
out = sum([att*split for (att, split) in zip(atten, splited)])
else:
out = atten * x
return out.contiguous()
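
To see the radix bookkeeping concretely, a hypothetical forward pass with BatchNorm2d standing in for a distributed norm_layer:

# Hypothetical usage sketch (not part of this commit)
import torch
from torch import nn
from encoding.nn import SplAtConv2d

conv = SplAtConv2d(64, 64, kernel_size=3, padding=1, groups=1,
                   radix=2, norm_layer=nn.BatchNorm2d)
x = torch.randn(2, 64, 32, 32)
y = conv(x)   # two radix splits of 64 channels are attention-weighted and summed
assert y.shape == (2, 64, 32, 32)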
@@ -22,10 +22,9 @@ from ..utils.misc import EncodingDeprecationWarning
from ..functions import *

-__all__ = ['SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']
+__all__ = ['DistSyncBatchNorm', 'SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']

-class SyncBatchNorm(_BatchNorm):
+class DistSyncBatchNorm(_BatchNorm):
    r"""Cross-GPU Synchronized Batch normalization (SyncBN)

    Standard BN [1]_ implementation only normalizes the data within each device (GPU).
@@ -71,10 +70,86 @@ class SyncBatchNorm(_BatchNorm):
    .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015*
    .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
Examples:
>>> m = DistSyncBatchNorm(100)
>>> net = torch.nn.parallel.DistributedDataParallel(m)
>>> output = net(input)
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, process_group=None):
super(DistSyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True, track_running_stats=True)
self.process_group = process_group
def forward(self, x):
need_sync = self.training or not self.track_running_stats
process_group = None
if need_sync:
process_group = torch.distributed.group.WORLD
if self.process_group:
process_group = self.process_group
world_size = torch.distributed.get_world_size(process_group)
need_sync = world_size > 1
# Resize the input to (B, C, -1).
input_shape = x.size()
x = x.view(input_shape[0], self.num_features, -1)
#def forward(ctx, x, gamma, beta, running_mean, running_var, eps, momentum, training, process_group):
y = dist_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
self.eps, self.momentum, self.training, process_group)
#_var = _exs - _ex ** 2
#running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * _ex)
#running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * _var)
return y.view(input_shape)
class SyncBatchNorm(_BatchNorm):
r"""Cross-GPU Synchronized Batch normalization (SyncBN)
    Standard BN [1]_ implementation only normalizes the data within each device (GPU).
SyncBN normalizes the input within the whole mini-batch.
    We follow the sync-once implementation described in the paper [2]_ .
Please see the design idea in the `notes <./notes/syncbn.html>`_.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
The mean and standard-deviation are calculated per-channel over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
sync: a boolean value that when set to ``True``, synchronize across
different gpus. Default: ``True``
activation : str
Name of the activation functions, one of: `leaky_relu` or `none`.
slope : float
Negative slope for the `leaky_relu` activation.
Shape:
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
    Examples:
        >>> m = SyncBatchNorm(100)
        >>> net = torch.nn.DataParallel(m)
        >>> output = net(input)
        >>> # for Inplace ABN
        >>> ABN = partial(SyncBatchNorm, activation='leaky_relu', slope=0.01, sync=True, inplace=True)
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation="none", slope=0.01,
...
-import torch
-from torchvision.transforms import *
+from .transforms import *
+from .get_transform import get_transform
-
-def get_transform(dataset, large_test_crop=False):
-    normalize = Normalize(mean=[0.485, 0.456, 0.406],
-                          std=[0.229, 0.224, 0.225])
-    if dataset == 'imagenet':
-        transform_train = Compose([
-            Resize(256),
-            RandomResizedCrop(224),
-            RandomHorizontalFlip(),
-            ColorJitter(0.4, 0.4, 0.4),
-            ToTensor(),
-            Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
-            normalize,
-        ])
-        if large_test_crop:
-            transform_val = Compose([
-                Resize(366),
-                CenterCrop(320),
-                ToTensor(),
-                normalize,
-            ])
-        else:
-            transform_val = Compose([
-                Resize(256),
-                CenterCrop(224),
-                ToTensor(),
-                normalize,
-            ])
-    elif dataset == 'minc':
-        transform_train = Compose([
-            Resize(256),
-            RandomResizedCrop(224),
-            RandomHorizontalFlip(),
-            ColorJitter(0.4, 0.4, 0.4),
-            ToTensor(),
-            Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
-            normalize,
-        ])
-        transform_val = Compose([
-            Resize(256),
-            CenterCrop(224),
-            ToTensor(),
-            normalize,
-        ])
-    elif dataset == 'cifar10':
-        transform_train = transforms.Compose([
-            transforms.RandomCrop(32, padding=4),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize((0.4914, 0.4822, 0.4465),
-                                 (0.2023, 0.1994, 0.2010)),
-        ])
-        transform_val = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.Normalize((0.4914, 0.4822, 0.4465),
-                                 (0.2023, 0.1994, 0.2010)),
-        ])
-    return transform_train, transform_val
-
-_imagenet_pca = {
-    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
-    'eigvec': torch.Tensor([
-        [-0.5675, 0.7192, 0.4009],
-        [-0.5808, -0.0045, -0.8140],
-        [-0.5836, -0.6948, 0.4203],
-    ])
-}
-
-class Lighting(object):
-    """Lighting noise(AlexNet - style PCA - based noise)"""
-    def __init__(self, alphastd, eigval, eigvec):
-        self.alphastd = alphastd
-        self.eigval = eigval
-        self.eigvec = eigvec
-
-    def __call__(self, img):
-        if self.alphastd == 0:
-            return img
-        alpha = img.new().resize_(3).normal_(0, self.alphastd)
-        rgb = self.eigvec.type_as(img).clone()\
-            .mul(alpha.view(1, 3).expand(3, 3))\
-            .mul(self.eigval.view(1, 3).expand(3, 3))\
-            .sum(1).squeeze()
-        return img.add(rgb.view(3, 1, 1).expand_as(img))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# code adapted from:
# https://github.com/kakaobrain/fast-autoaugment
# https://github.com/rpmcruz/autoaugment
import math
import random
import numpy as np
from collections import defaultdict
import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
RESAMPLE_MODE = PIL.Image.BICUBIC  # alternative: PIL.Image.BILINEAR
RANDOM_MIRROR = True
def ShearX(img, v, resample=RESAMPLE_MODE): # [-0.3, 0.3]
assert -0.3 <= v <= 0.3
if RANDOM_MIRROR and random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0),
resample=resample)
def ShearY(img, v, resample=RESAMPLE_MODE): # [-0.3, 0.3]
assert -0.3 <= v <= 0.3
if RANDOM_MIRROR and random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0),
resample=resample)
def TranslateX(img, v, resample=RESAMPLE_MODE): # [-150, 150] => percentage: [-0.45, 0.45]
assert -0.45 <= v <= 0.45
if RANDOM_MIRROR and random.random() > 0.5:
v = -v
v = v * img.size[0]
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0),
resample=resample)
def TranslateY(img, v, resample=RESAMPLE_MODE): # [-150, 150] => percentage: [-0.45, 0.45]
assert -0.45 <= v <= 0.45
if RANDOM_MIRROR and random.random() > 0.5:
v = -v
v = v * img.size[1]
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v),
resample=resample)
def TranslateXabs(img, v, resample=RESAMPLE_MODE): # [-150, 150] => percentage: [-0.45, 0.45]
assert 0 <= v
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0),
resample=resample)
def TranslateYabs(img, v, resample=RESAMPLE_MODE): # [-150, 150] => percentage: [-0.45, 0.45]
assert 0 <= v
if random.random() > 0.5:
v = -v
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v),
resample=resample)
def Rotate(img, v): # [-30, 30]
assert -30 <= v <= 30
if RANDOM_MIRROR and random.random() > 0.5:
v = -v
return img.rotate(v)
def AutoContrast(img, _):
return PIL.ImageOps.autocontrast(img)
def Invert(img, _):
return PIL.ImageOps.invert(img)
def Equalize(img, _):
return PIL.ImageOps.equalize(img)
def Flip(img, _): # not from the paper
return PIL.ImageOps.mirror(img)
def Solarize(img, v): # [0, 256]
assert 0 <= v <= 256
return PIL.ImageOps.solarize(img, v)
def SolarizeAdd(img, addition=0, threshold=128):
    img_np = np.array(img).astype(int)  # np.int is deprecated in recent NumPy
img_np = img_np + addition
img_np = np.clip(img_np, 0, 255)
img_np = img_np.astype(np.uint8)
img = PIL.Image.fromarray(img_np)
return PIL.ImageOps.solarize(img, threshold)
def Posterize(img, v): # [4, 8]
#assert 4 <= v <= 8
v = int(v)
return PIL.ImageOps.posterize(img, v)
def Contrast(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Contrast(img).enhance(v)
def Color(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Color(img).enhance(v)
def Brightness(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Brightness(img).enhance(v)
def Sharpness(img, v): # [0.1,1.9]
assert 0.1 <= v <= 1.9
return PIL.ImageEnhance.Sharpness(img).enhance(v)
def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
# assert 0 <= v <= 20
if v < 0:
return img
w, h = img.size
x0 = np.random.uniform(w)
y0 = np.random.uniform(h)
x0 = int(max(0, x0 - v / 2.))
y0 = int(max(0, y0 - v / 2.))
x1 = min(w, x0 + v)
y1 = min(h, y0 + v)
xy = (x0, y0, x1, y1)
color = (125, 123, 114)
# color = (0, 0, 0)
img = img.copy()
PIL.ImageDraw.Draw(img).rectangle(xy, color)
return img
def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
assert 0.0 <= v <= 0.2
if v <= 0.:
return img
v = v * img.size[0]
return CutoutAbs(img, v)
def rand_augment_list():  # 16 operations and their ranges
l = [
(AutoContrast, 0, 1),
(Equalize, 0, 1),
(Invert, 0, 1),
(Rotate, 0, 30),
(Posterize, 0, 4),
(Solarize, 0, 256),
(SolarizeAdd, 0, 110),
(Color, 0.1, 1.9),
(Contrast, 0.1, 1.9),
(Brightness, 0.1, 1.9),
(Sharpness, 0.1, 1.9),
(ShearX, 0., 0.3),
(ShearY, 0., 0.3),
(CutoutAbs, 0, 40),
(TranslateXabs, 0., 100),
(TranslateYabs, 0., 100),
]
return l
class RandAugment(object):
def __init__(self, n, m):
self.n = n
self.m = m
self.augment_list = rand_augment_list()
def __call__(self, img):
ops = random.choices(self.augment_list, k=self.n)
for op, minval, maxval in ops:
if random.random() > random.uniform(0.2, 0.8):
continue
val = (float(self.m) / 30) * float(maxval - minval) + minval
img = op(img, val)
return img
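
A hypothetical way to drop the augmenter into a pipeline; get_transform below wires it up the same way when rand_aug=True.

# Hypothetical usage sketch (not part of this commit)
from PIL import Image
from encoding.transforms.autoaug import RandAugment

aug = RandAugment(n=2, m=9)   # 2 ops per image at magnitude 9 (of 30)
img = Image.new('RGB', (224, 224), (128, 128, 128))
out = aug(img)                # each sampled op may also be skipped at random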
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
from torchvision.transforms import *
from .transforms import *
def get_transform(dataset, base_size=None, crop_size=224, rand_aug=False, etrans=True, **kwargs):
normalize = Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
base_size = base_size if base_size is not None else int(1.0 * crop_size / 0.875)
if dataset == 'imagenet':
train_transforms = []
val_transforms = []
if rand_aug:
from .autoaug import RandAugment
train_transforms.append(RandAugment(2, 12))
if etrans:
train_transforms.extend([
ERandomCrop(crop_size),
])
val_transforms.extend([
ECenterCrop(crop_size),
])
else:
train_transforms.extend([
RandomResizedCrop(crop_size),
])
val_transforms.extend([
Resize(base_size),
CenterCrop(crop_size),
])
train_transforms.extend([
RandomHorizontalFlip(),
ColorJitter(0.4, 0.4, 0.4),
ToTensor(),
Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
normalize,
])
val_transforms.extend([
ToTensor(),
normalize,
])
transform_train = Compose(train_transforms)
transform_val = Compose(val_transforms)
elif dataset == 'minc':
transform_train = Compose([
Resize(base_size),
RandomResizedCrop(crop_size),
RandomHorizontalFlip(),
ColorJitter(0.4, 0.4, 0.4),
ToTensor(),
Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']),
normalize,
])
transform_val = Compose([
Resize(base_size),
CenterCrop(crop_size),
ToTensor(),
normalize,
])
    elif dataset == 'cifar10':
        # only the star-imported torchvision.transforms names are in scope here,
        # so the classes are used without a `transforms.` prefix
        transform_train = Compose([
            RandomCrop(32, padding=4),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize((0.4914, 0.4822, 0.4465),
                      (0.2023, 0.1994, 0.2010)),
        ])
        transform_val = Compose([
            ToTensor(),
            Normalize((0.4914, 0.4822, 0.4465),
                      (0.2023, 0.1994, 0.2010)),
        ])
return transform_train, transform_val
_imagenet_pca = {
'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
'eigvec': torch.Tensor([
[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203],
])
}
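
Typical call, as a hypothetical sketch (the dataset path and ImageFolder wiring are placeholders):

# Hypothetical usage sketch (not part of this commit)
from encoding.transforms import get_transform

transform_train, transform_val = get_transform('imagenet', crop_size=224, rand_aug=True)
# e.g. torchvision.datasets.ImageFolder('/path/to/train', transform=transform_train)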
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import math
import random
from PIL import Image
from torchvision.transforms import Resize
__all__ = ['Lighting', 'ERandomCrop', 'ECenterCrop']
class Lighting(object):
"""Lighting noise(AlexNet - style PCA - based noise)"""
def __init__(self, alphastd, eigval, eigvec):
self.alphastd = alphastd
self.eigval = eigval
self.eigvec = eigvec
def __call__(self, img):
if self.alphastd == 0:
return img
alpha = img.new().resize_(3).normal_(0, self.alphastd)
rgb = self.eigvec.type_as(img).clone()\
.mul(alpha.view(1, 3).expand(3, 3))\
.mul(self.eigval.view(1, 3).expand(3, 3))\
.sum(1).squeeze()
return img.add(rgb.view(3, 1, 1).expand_as(img))
#https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/data.py
class ERandomCrop:
def __init__(self, imgsize, min_covered=0.1, aspect_ratio_range=(3./4, 4./3),
area_range=(0.1, 1.0), max_attempts=10):
assert 0.0 < min_covered
assert 0 < aspect_ratio_range[0] <= aspect_ratio_range[1]
assert 0 < area_range[0] <= area_range[1]
assert 1 <= max_attempts
self.imgsize = imgsize
self.min_covered = min_covered
self.aspect_ratio_range = aspect_ratio_range
self.area_range = area_range
self.max_attempts = max_attempts
self._fallback = ECenterCrop(imgsize)
self.resize_method = Resize((imgsize, imgsize), interpolation=Image.BICUBIC)
def __call__(self, img):
original_width, original_height = img.size
min_area = self.area_range[0] * (original_width * original_height)
max_area = self.area_range[1] * (original_width * original_height)
for _ in range(self.max_attempts):
aspect_ratio = random.uniform(*self.aspect_ratio_range)
height = int(round(math.sqrt(min_area / aspect_ratio)))
max_height = int(round(math.sqrt(max_area / aspect_ratio)))
if max_height * aspect_ratio > original_width:
max_height = (original_width + 0.5 - 1e-7) / aspect_ratio
max_height = int(max_height)
if max_height * aspect_ratio > original_width:
max_height -= 1
if max_height > original_height:
max_height = original_height
if height >= max_height:
height = max_height
height = int(round(random.uniform(height, max_height)))
width = int(round(height * aspect_ratio))
area = width * height
if area < min_area or area > max_area:
continue
if width > original_width or height > original_height:
continue
if area < self.min_covered * (original_width * original_height):
continue
if width == original_width and height == original_height:
return self._fallback(img)
x = random.randint(0, original_width - width)
y = random.randint(0, original_height - height)
img = img.crop((x, y, x + width, y + height))
return self.resize_method(img)
return self._fallback(img)
class ECenterCrop:
"""Crop the given PIL Image and resize it to desired size.
Args:
img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
output_size (sequence or int): (height, width) of the crop box. If int,
it is used for both directions
Returns:
PIL Image: Cropped image.
"""
def __init__(self, imgsize):
self.imgsize = imgsize
self.resize_method = Resize((imgsize, imgsize), interpolation=Image.BICUBIC)
def __call__(self, img):
image_width, image_height = img.size
image_short = min(image_width, image_height)
crop_size = float(self.imgsize) / (self.imgsize + 32) * image_short
crop_height, crop_width = crop_size, crop_size
crop_top = int(round((image_height - crop_height) / 2.))
crop_left = int(round((image_width - crop_width) / 2.))
img = img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
return self.resize_method(img)
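
ECenterCrop reproduces the Inception-style evaluation crop: the short side is cropped by a factor of imgsize/(imgsize + 32) and the result resized to the target size. A hypothetical check:

# Hypothetical usage sketch (not part of this commit)
from PIL import Image
from encoding.transforms import ECenterCrop

crop = ECenterCrop(224)
img = Image.new('RGB', (500, 375))
out = crop(img)               # 224/(224+32) short-side center crop, then bicubic resize
assert out.size == (224, 224)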
@@ -9,14 +9,10 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Encoding Util Tools"""
-from .lr_scheduler import LR_Scheduler
-from .metrics import SegmentationMetric, batch_intersection_union, batch_pix_accuracy
+from .lr_scheduler import *
+from .metrics import *
from .pallete import get_mask_pallete
from .train_helper import *
from .presets import load_image
from .files import *
from .misc import *
-
-__all__ = ['LR_Scheduler', 'batch_pix_accuracy', 'batch_intersection_union',
-           'save_checkpoint', 'download', 'mkdir', 'check_sha1', 'load_image',
-           'get_mask_pallete', 'get_selabel_vector', 'EMA']
@@ -10,7 +10,10 @@ __all__ = ['save_checkpoint', 'download', 'mkdir', 'check_sha1']
def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'):
    """Saves checkpoint to disk"""
-    directory = "runs/%s/%s/%s/"%(args.dataset, args.model, args.checkname)
+    if hasattr(args, 'backbone'):
+        directory = "runs/%s/%s/%s/%s/"%(args.dataset, args.model, args.backbone, args.checkname)
+    else:
+        directory = "runs/%s/%s/%s/"%(args.dataset, args.model, args.checkname)
    if not os.path.exists(directory):
        os.makedirs(directory)
    filename = directory + filename
...
@@ -10,6 +10,8 @@
import math

+__all__ = ['LR_Scheduler', 'LR_Scheduler_Head']
+
class LR_Scheduler(object):
    """Learning Rate Scheduler
@@ -29,36 +31,44 @@ class LR_Scheduler(object):
    def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
                 lr_step=0, warmup_epochs=0):
        self.mode = mode
-        print('Using {} LR Scheduler!'.format(self.mode))
-        self.lr = base_lr
+        print('Using {} LR scheduler with warm-up epochs of {}!'.format(self.mode, warmup_epochs))
        if mode == 'step':
            assert lr_step
+        self.base_lr = base_lr
        self.lr_step = lr_step
        self.iters_per_epoch = iters_per_epoch
-        self.N = num_epochs * iters_per_epoch
        self.epoch = -1
        self.warmup_iters = warmup_epochs * iters_per_epoch
+        self.total_iters = (num_epochs - warmup_epochs) * iters_per_epoch

    def __call__(self, optimizer, i, epoch, best_pred):
        T = epoch * self.iters_per_epoch + i
-        if self.mode == 'cos':
-            lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi))
+        # warm up lr schedule
+        if self.warmup_iters > 0 and T < self.warmup_iters:
+            lr = self.base_lr * 1.0 * T / self.warmup_iters
+        elif self.mode == 'cos':
+            T = T - self.warmup_iters
+            lr = 0.5 * self.base_lr * (1 + math.cos(1.0 * T / self.total_iters * math.pi))
        elif self.mode == 'poly':
-            lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9)
+            T = T - self.warmup_iters
+            lr = self.base_lr * pow((1 - 1.0 * T / self.total_iters), 0.9)
        elif self.mode == 'step':
-            lr = self.lr * (0.1 ** (epoch // self.lr_step))
+            lr = self.base_lr * (0.1 ** (epoch // self.lr_step))
        else:
            raise NotImplementedError
-        # warm up lr schedule
-        if self.warmup_iters > 0 and T < self.warmup_iters:
-            lr = lr * 1.0 * T / self.warmup_iters
-        if epoch > self.epoch:
-            print('\n=>Epoches %i, learning rate = %.4f, \
+        if epoch > self.epoch and (epoch == 0 or best_pred > 0.0):
+            print('\n=>Epoch %i, learning rate = %.4f, \
previous best = %.4f' % (epoch, lr, best_pred))
            self.epoch = epoch
        assert lr >= 0
        self._adjust_learning_rate(optimizer, lr)

+    def _adjust_learning_rate(self, optimizer, lr):
+        for i in range(len(optimizer.param_groups)):
+            optimizer.param_groups[i]['lr'] = lr
+
+
+class LR_Scheduler_Head(LR_Scheduler):
+    """Increase the LR of the additional head to be 10 times"""
    def _adjust_learning_rate(self, optimizer, lr):
        if len(optimizer.param_groups) == 1:
            optimizer.param_groups[0]['lr'] = lr
...
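
The scheduler is called once per iteration; a hypothetical loop (model and optimizer are placeholders) showing the linear warm-up followed by cosine decay:

# Hypothetical usage sketch (not part of this commit)
import torch
from encoding.utils import LR_Scheduler

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = LR_Scheduler('cos', base_lr=0.1, num_epochs=120,
                         iters_per_epoch=500, warmup_epochs=5)
for epoch in range(120):
    for i in range(500):
        scheduler(optimizer, i, epoch, best_pred=0.0)
        # ... forward / backward / optimizer.step()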
@@ -12,6 +12,25 @@ import threading
import numpy as np
import torch

__all__ = ['accuracy', 'SegmentationMetric', 'batch_intersection_union', 'batch_pix_accuracy',
'pixel_accuracy', 'intersection_and_union']
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
class SegmentationMetric(object):
    """Computes pixAcc and mIoU metric scores
    """
...
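
A hypothetical call of the new accuracy helper:

# Hypothetical usage sketch (not part of this commit)
import torch
from encoding.utils import accuracy

output = torch.randn(8, 10)            # (batch, classes) logits
target = torch.randint(0, 10, (8,))
top1, top5 = accuracy(output, target, topk=(1, 5))   # percentages, 1-element tensors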
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import warnings

-__all__ = ['EncodingDeprecationWarning']
+__all__ = ['AverageMeter', 'EncodingDeprecationWarning']
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
#self.val = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
#self.val = val
self.sum += val * n
self.count += n
@property
def avg(self):
avg = 0 if self.count == 0 else self.sum / self.count
return avg
class EncodingDeprecationWarning(DeprecationWarning):
    pass
...
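
A hypothetical use of the new meter; avg is weighted by the sample counts passed to update:

# Hypothetical usage sketch (not part of this commit)
from encoding.utils import AverageMeter

meter = AverageMeter()
meter.update(0.9, n=32)   # e.g. batch loss 0.9 over 32 samples
meter.update(0.7, n=32)
print(meter.avg)          # 0.8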
@@ -8,13 +8,45 @@
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

+import numpy as np
import torch
import torch.nn as nn
#from ..nn import SyncBatchNorm
from torch.nn.modules.batchnorm import _BatchNorm

-__all__ = ['get_selabel_vector']
+__all__ = ['MixUpWrapper', 'get_selabel_vector']
class MixUpWrapper(object):
def __init__(self, alpha, num_classes, dataloader, device):
self.alpha = alpha
self.dataloader = dataloader
self.num_classes = num_classes
self.device = device
def mixup_loader(self, loader):
def mixup(alpha, num_classes, data, target):
with torch.no_grad():
bs = data.size(0)
c = np.random.beta(alpha, alpha)
perm = torch.randperm(bs).cuda()
md = c * data + (1-c) * data[perm, :]
mt = c * target + (1-c) * target[perm, :]
return md, mt
for input, target in loader:
input, target = input.cuda(self.device), target.cuda(self.device)
target = torch.nn.functional.one_hot(target, self.num_classes)
i, t = mixup(self.alpha, self.num_classes, input, target)
yield i, t
def __len__(self):
return len(self.dataloader)
def __iter__(self):
return self.mixup_loader(self.dataloader)
def get_selabel_vector(target, nclass):
    r"""Get SE-Loss Label in a batch
@@ -34,45 +66,3 @@ def get_selabel_vector(target, nclass):
        vect = hist>0
        tvect[i] = vect
    return tvect
-
-class EMA():
-    r""" Use moving avg for the models.
-
-    Examples:
-        >>> ema = EMA(0.999)
-        >>> for name, param in model.named_parameters():
-        >>>     if param.requires_grad:
-        >>>         ema.register(name, param.data)
-        >>>
-        >>> # during training:
-        >>> # optimizer.step()
-        >>> for name, param in model.named_parameters():
-        >>>     # Sometime I also use the moving average of non-trainable parameters, just according to the model structure
-        >>>     if param.requires_grad:
-        >>>         ema(name, param.data)
-        >>>
-        >>> # during eval or test
-        >>> import copy
-        >>> model_test = copy.deepcopy(model)
-        >>> for name, param in model_test.named_parameters():
-        >>>     # Sometime I also use the moving average of non-trainable parameters, just according to the model structure
-        >>>     if param.requires_grad:
-        >>>         param.data = ema.get(name)
-        >>> # Then use model_test for eval.
-    """
-    def __init__(self, momentum):
-        self.momentum = momentum
-        self.shadow = {}
-
-    def register(self, name, val):
-        self.shadow[name] = val.clone()
-
-    def __call__(self, name, x):
-        assert name in self.shadow
-        new_average = (1.0 - self.momentum) * x + self.momentum * self.shadow[name]
-        self.shadow[name] = new_average.clone()
-        return new_average
-
-    def get(self, name):
-        assert name in self.shadow
-        return self.shadow[name]
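
MixUpWrapper pairs naturally with NLLMultiLabelSmooth, since the wrapped loader yields mixed one-hot targets. A hypothetical sketch; train_loader and model are placeholders, and a CUDA device is assumed because mixup permutes batches on the GPU.

# Hypothetical usage sketch (not part of this commit)
from encoding.nn import NLLMultiLabelSmooth
from encoding.utils import MixUpWrapper

wrapped = MixUpWrapper(alpha=0.2, num_classes=1000,
                       dataloader=train_loader, device=0)
criterion = NLLMultiLabelSmooth(smoothing=0.1)
for data, target in wrapped:              # target arrives one-hot and mixed
    loss = criterion(model(data), target)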