Unverified Commit b872eb8c authored by Hang Zhang, committed by GitHub

ResNeSt plus (#256)

parent 5a1e3fbc
@@ -12,4 +12,8 @@
from .encoding import *
from .syncbn import *
from .customize import *
from .attention import *
from .loss import *
from .rectify import *
from .splat import SplAtConv2d
from .dropblock import *
###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2018
###########################################################################
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from .syncbn import SyncBatchNorm
__all__ = ['ACFModule', 'MixtureOfSoftMaxACF']
class ACFModule(nn.Module):
""" Multi-Head Attention module """
def __init__(self, n_head, n_mix, d_model, d_k, d_v, norm_layer=SyncBatchNorm,
kq_transform='conv', value_transform='conv',
pooling=True, concat=False, dropout=0.1):
super(ACFModule, self).__init__()
self.n_head = n_head
self.n_mix = n_mix
self.d_k = d_k
self.d_v = d_v
self.pooling = pooling
self.concat = concat
if self.pooling:
self.pool = nn.AvgPool2d(3, 2, 1, count_include_pad=False)
if kq_transform == 'conv':
self.conv_qs = nn.Conv2d(d_model, n_head*d_k, 1)
nn.init.normal_(self.conv_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
elif kq_transform == 'ffn':
self.conv_qs = nn.Sequential(
nn.Conv2d(d_model, n_head*d_k, 3, padding=1, bias=False),
norm_layer(n_head*d_k),
nn.ReLU(True),
nn.Conv2d(n_head*d_k, n_head*d_k, 1),
)
nn.init.normal_(self.conv_qs[-1].weight, mean=0, std=np.sqrt(1.0 / d_k))
elif kq_transform == 'dffn':
self.conv_qs = nn.Sequential(
nn.Conv2d(d_model, n_head*d_k, 3, padding=4, dilation=4, bias=False),
norm_layer(n_head*d_k),
nn.ReLU(True),
nn.Conv2d(n_head*d_k, n_head*d_k, 1),
)
nn.init.normal_(self.conv_qs[-1].weight, mean=0, std=np.sqrt(1.0 / d_k))
else:
raise NotImplementedError
#self.conv_ks = nn.Conv2d(d_model, n_head*d_k, 1)
self.conv_ks = self.conv_qs
if value_transform == 'conv':
self.conv_vs = nn.Conv2d(d_model, n_head*d_v, 1)
else:
raise NotImplementedError
#nn.init.normal_(self.conv_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.conv_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
self.attention = MixtureOfSoftMaxACF(n_mix=n_mix, d_k=d_k)
self.conv = nn.Conv2d(n_head*d_v, d_model, 1, bias=False)
self.norm_layer = norm_layer(d_model)
def forward(self, x):
residual = x
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
b_, c_, h_, w_ = x.size()
if self.pooling:
qt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
kt = self.conv_ks(self.pool(x)).view(b_*n_head, d_k, h_*w_//4)
vt = self.conv_vs(self.pool(x)).view(b_*n_head, d_v, h_*w_//4)
else:
kt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
qt = kt
vt = self.conv_vs(x).view(b_*n_head, d_v, h_*w_)
output, attn = self.attention(qt, kt, vt)
output = output.transpose(1, 2).contiguous().view(b_, n_head*d_v, h_, w_)
output = self.conv(output)
if self.concat:
output = torch.cat((self.norm_layer(output), residual), 1)
else:
output = self.norm_layer(output) + residual
return output
def demo(self, x):
residual = x
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
b_, c_, h_, w_ = x.size()
if self.pooling:
qt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
kt = self.conv_ks(self.pool(x)).view(b_*n_head, d_k, h_*w_//4)
vt = self.conv_vs(self.pool(x)).view(b_*n_head, d_v, h_*w_//4)
else:
kt = self.conv_ks(x).view(b_*n_head, d_k, h_*w_)
qt = kt
vt = self.conv_vs(x).view(b_*n_head, d_v, h_*w_)
_, attn = self.attention(qt, kt, vt)
attn = attn.view(b_, n_head, h_*w_, -1)
return attn
def extra_repr(self):
return 'n_head={}, n_mix={}, d_k={}, pooling={}' \
.format(self.n_head, self.n_mix, self.d_k, self.pooling)
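For reference, a minimal smoke test of ACFModule (not part of this commit; the sizes below are illustrative and nn.BatchNorm2d is substituted for SyncBatchNorm so it runs in a single CPU process):

import torch
import torch.nn as nn
# assumes ACFModule is importable, e.g. via `from encoding.nn import ACFModule` after this commit

acf = ACFModule(n_head=8, n_mix=2, d_model=512, d_k=64, d_v=64,
                norm_layer=nn.BatchNorm2d, pooling=True, concat=False)
x = torch.randn(2, 512, 32, 32)    # (B, d_model, H, W); even H, W so the pooled
                                   # keys/values reshape cleanly to h_*w_//4 positions
out = acf(x)                       # attention output + residual, same shape as x
print(out.shape)                   # torch.Size([2, 512, 32, 32])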
class MixtureOfSoftMaxACF(nn.Module):
""""Mixture of SoftMax"""
def __init__(self, n_mix, d_k, attn_dropout=0.1):
super(MixtureOfSoftMaxACF, self).__init__()
self.temperature = np.power(d_k, 0.5)
self.n_mix = n_mix
self.att_drop = attn_dropout
self.dropout = nn.Dropout(attn_dropout)
self.softmax1 = nn.Softmax(dim=1)
self.softmax2 = nn.Softmax(dim=2)
self.d_k = d_k
if n_mix > 1:
self.weight = nn.Parameter(torch.Tensor(n_mix, d_k))
std = np.power(n_mix, -0.5)
self.weight.data.uniform_(-std, std)
def forward(self, qt, kt, vt):
B, d_k, N = qt.size()
m = self.n_mix
assert d_k == self.d_k
d = d_k // m
if m > 1:
# \bar{v} \in R^{B, d_k, 1}
bar_qt = torch.mean(qt, 2, True)
# pi \in R^{B, m, 1}
pi = self.softmax1(torch.matmul(self.weight, bar_qt)).view(B*m, 1, 1)
# reshape for n_mix
q = qt.view(B*m, d, N).transpose(1, 2)
N2 = kt.size(2)
kt = kt.view(B*m, d, N2)
v = vt.transpose(1, 2)
# {Bm, N, N}
attn = torch.bmm(q, kt)
attn = attn / self.temperature
attn = self.softmax2(attn)
attn = self.dropout(attn)
if m > 1:
# attn \in R^{Bm, N, N2} => R^{B, N, N2}
attn = (attn * pi).view(B, m, N, N2).sum(1)
output = torch.bmm(attn, v)
return output, attn
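For readability, my reading of the forward pass above (not text from the commit): per head the query map is averaged over positions to get $\bar{q}$, the mixture weights are $\pi = \operatorname{softmax}(W\bar{q})$ with a learnable $W \in \mathbb{R}^{m \times d_k}$, and the attention is a mixture of softmaxes over the $m$ channel splits:

$$\mathrm{Attn} = \sum_{i=1}^{m} \pi_i \, \operatorname{softmax}\!\left(\frac{Q_i^{\top} K_i}{\sqrt{d_k}}\right), \qquad \mathrm{output} = \mathrm{Attn}\, V^{\top}$$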
@@ -28,8 +28,6 @@ class GlobalAvgPool2d(nn.Module):
def forward(self, inputs):
return F.adaptive_avg_pool2d(inputs, 1).view(inputs.size(0), -1)
class GramMatrix(nn.Module):
r""" Gram Matrix for a 4D convolutional featuremaps as a mini-batch
......
# https://github.com/Randl/MobileNetV3-pytorch/blob/master/dropblock.py
import torch
import torch.nn.functional as F
from torch import nn
__all__ = ['DropBlock2D', 'reset_dropblock']
class DropBlock2D(nn.Module):
r"""Randomly zeroes 2D spatial blocks of the input tensor.
As described in the paper
`DropBlock: A regularization method for convolutional networks`_ ,
dropping whole blocks of a feature map removes correlated semantic
information more effectively than regular dropout.
Args:
drop_prob (float): probability of an element to be dropped.
block_size (int): size of the block to drop
Shape:
- Input: `(N, C, H, W)`
- Output: `(N, C, H, W)`
.. _DropBlock: A regularization method for convolutional networks:
https://arxiv.org/abs/1810.12890
"""
def __init__(self, drop_prob, block_size, share_channel=False):
super(DropBlock2D, self).__init__()
self.register_buffer('i', torch.zeros(1, dtype=torch.int64))
self.register_buffer('drop_prob', drop_prob * torch.ones(1, dtype=torch.float32))
self.inited = False
self.step_size = 0.0
self.start_step = 0
self.nr_steps = 0
self.block_size = block_size
self.share_channel = share_channel
def reset(self):
"""stop DropBlock"""
self.inited = True
self.i[0] = 0
self.drop_prob[0] = 0.0  # keep the registered buffer a tensor
def reset_steps(self, start_step, nr_steps, start_value=0, stop_value=None):
self.inited = True
stop_value = self.drop_prob.item() if stop_value is None else stop_value
self.i[0] = 0
self.drop_prob[0] = start_value
self.step_size = (stop_value - start_value) / nr_steps
self.nr_steps = nr_steps
self.start_step = start_step
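# Note: this sets up a linear ramp for drop_prob: it starts at start_value and
# is increased by step_size on every forward pass (see step() below) once the
# internal counter i passes start_step, reaching stop_value after nr_steps updates.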
def forward(self, x):
if not self.training or self.drop_prob.item() == 0.:
return x
else:
self.step()
# get gamma value
gamma = self._compute_gamma(x)
# sample mask and place on input device
if self.share_channel:
# one mask shared across channels: shape (N, 1, H, W) so it broadcasts over C
mask = (torch.rand(x.shape[0], *x.shape[2:], device=x.device, dtype=x.dtype) < gamma).unsqueeze(1)
else:
mask = (torch.rand(*x.shape, device=x.device, dtype=x.dtype) < gamma)
# compute block mask
block_mask, keeped = self._compute_block_mask(mask)
# apply block mask
out = x * block_mask
# scale output
out = out * (block_mask.numel() / keeped).to(out)
return out
def _compute_block_mask(self, mask):
# cast the boolean mask to float: max_pool2d has no bool kernel
block_mask = F.max_pool2d(mask.to(torch.float32),
kernel_size=(self.block_size, self.block_size),
stride=(1, 1),
padding=self.block_size // 2)
keeped = block_mask.numel() - block_mask.sum().to(torch.float32)
block_mask = 1 - block_mask
return block_mask, keeped
def _compute_gamma(self, x):
_, c, h, w = x.size()
gamma = self.drop_prob.item() / (self.block_size ** 2) * (h * w) / \
((w - self.block_size + 1) * (h - self.block_size + 1))
return gamma
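# Note: the gamma above follows the DropBlock paper:
#   gamma = drop_prob / block_size**2 * (H * W) / ((W - block_size + 1) * (H - block_size + 1))
# i.e. the Bernoulli rate for mask seeds, spread over block_size**2 pixels per seed
# and corrected for the border where a full block cannot be centered.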
def step(self):
assert self.inited
idx = self.i.item()
if idx > self.start_step and idx < self.start_step + self.nr_steps:
self.drop_prob += self.step_size
self.i += 1
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
idx_key = prefix + 'i'
drop_prob_key = prefix + 'drop_prob'
if idx_key not in state_dict:
state_dict[idx_key] = torch.zeros(1, dtype=torch.int64)
if drop_prob_key not in state_dict:
state_dict[drop_prob_key] = torch.ones(1, dtype=torch.float32)
super(DropBlock2D, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
"""overwrite save method"""
pass
def extra_repr(self):
return 'drop_prob={}, step_size={}'.format(self.drop_prob, self.step_size)
def reset_dropblock(start_step, nr_steps, start_value, stop_value, m):
"""
Example:
from functools import partial
apply_drop_prob = partial(reset_dropblock, 0, epochs*iters_per_epoch, 0.0, 0.1)
net.apply(apply_drop_prob)
"""
if isinstance(m, DropBlock2D):
print('resetting dropblock')
m.reset_steps(start_step, nr_steps, start_value, stop_value)
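A quick, hypothetical smoke test for DropBlock2D (not from the repo; values are illustrative). reset_steps is used here only to mark the module as initialised while keeping drop_prob constant; during training you would ramp drop_prob with reset_dropblock as shown in the docstring above:

import torch

db = DropBlock2D(drop_prob=0.1, block_size=3)
db.reset_steps(start_step=0, nr_steps=1, start_value=0.1, stop_value=0.1)  # pin drop_prob at 0.1
db.train()

x = torch.randn(2, 16, 32, 32)
y = db(x)                         # contiguous blocks zeroed, survivors rescaled
print(y.shape)                    # torch.Size([2, 16, 32, 32])
print((y == 0).float().mean())    # rough fraction of dropped activations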
@@ -17,7 +17,8 @@ from torch.nn.modules.utils import _pair
from ..functions import scaled_l2, aggregate, pairwise_cosine
__all__ = ['Encoding', 'EncodingDrop', 'Inspiration', 'UpsampleConv2d']
__all__ = ['Encoding', 'EncodingDrop', 'Inspiration', 'UpsampleConv2d',
'EncodingCosine']
class Encoding(Module):
r"""
@@ -304,3 +305,43 @@ class UpsampleConv2d(Module):
out = F.conv2d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
return F.pixel_shuffle(out, self.scale_factor)
# Experimental
class EncodingCosine(Module):
def __init__(self, D, K):
super(EncodingCosine, self).__init__()
# init codewords and smoothing factor
self.D, self.K = D, K
self.codewords = Parameter(torch.Tensor(K, D), requires_grad=True)
#self.scale = Parameter(torch.Tensor(K), requires_grad=True)
self.reset_params()
def reset_params(self):
std1 = 1./((self.K*self.D)**(1/2))
self.codewords.data.uniform_(-std1, std1)
#self.scale.data.uniform_(-1, 0)
def forward(self, X):
# input X is a 4D tensor
assert(X.size(1) == self.D)
if X.dim() == 3:
# BxDxN
B, D = X.size(0), self.D
X = X.transpose(1, 2).contiguous()
elif X.dim() == 4:
# BxDxHxW
B, D = X.size(0), self.D
X = X.view(B, D, -1).transpose(1, 2).contiguous()
else:
raise RuntimeError('Encoding Layer unknown input dims!')
# assignment weights NxKxD
L = pairwise_cosine(X, self.codewords)
A = F.softmax(L, dim=2)
# aggregate
E = aggregate(A, X, self.codewords)
return E
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'N x ' + str(self.D) + '=>' + str(self.K) + 'x' \
+ str(self.D) + ')'
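A sketch of how the experimental layer might be exercised (my example, assuming the compiled encoding extensions that provide pairwise_cosine and aggregate are built; EncodingCosine is exported through the updated __all__ above):

import torch

enc = EncodingCosine(D=256, K=32)
x = torch.randn(4, 256, 16, 16)   # B x D x H x W feature map
e = enc(x)                        # B x K x D aggregated encodings
print(e.shape)                    # torch.Size([4, 32, 256])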
@@ -2,128 +2,60 @@ import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
__all__ = ['SegmentationLosses', 'OhemCrossEntropy2d', 'OHEMSegmentationLosses']
class SegmentationLosses(nn.CrossEntropyLoss):
"""2D Cross Entropy Loss with Auxilary Loss"""
def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
aux=False, aux_weight=0.4, weight=None,
ignore_index=-1):
super(SegmentationLosses, self).__init__(weight, None, ignore_index)
self.se_loss = se_loss
self.aux = aux
self.nclass = nclass
self.se_weight = se_weight
self.aux_weight = aux_weight
self.bceloss = nn.BCELoss(weight)
def forward(self, *inputs):
if not self.se_loss and not self.aux:
return super(SegmentationLosses, self).forward(*inputs)
elif not self.se_loss:
pred1, pred2, target = tuple(inputs)
loss1 = super(SegmentationLosses, self).forward(pred1, target)
loss2 = super(SegmentationLosses, self).forward(pred2, target)
return loss1 + self.aux_weight * loss2
elif not self.aux:
pred, se_pred, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
loss1 = super(SegmentationLosses, self).forward(pred, target)
loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.se_weight * loss2
else:
pred1, se_pred, pred2, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
loss1 = super(SegmentationLosses, self).forward(pred1, target)
loss2 = super(SegmentationLosses, self).forward(pred2, target)
loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.aux_weight * loss2 + self.se_weight * loss3
__all__ = ['LabelSmoothing', 'NLLMultiLabelSmooth', 'SegmentationLosses']
@staticmethod
def _get_batch_label_vector(target, nclass):
# target is a 3D Variable BxHxW, output is 2D BxnClass
batch = target.size(0)
tvect = Variable(torch.zeros(batch, nclass))
for i in range(batch):
hist = torch.histc(target[i].cpu().data.float(),
bins=nclass, min=0,
max=nclass-1)
vect = hist>0
tvect[i] = vect
return tvect
# adapted from https://github.com/PkuRainBow/OCNet/blob/master/utils/loss.py
class OhemCrossEntropy2d(nn.Module):
def __init__(self, ignore_label=-1, thresh=0.7, min_kept=100000, use_weight=True):
super(OhemCrossEntropy2d, self).__init__()
self.ignore_label = ignore_label
self.thresh = float(thresh)
self.min_kept = int(min_kept)
if use_weight:
print("w/ class balance")
weight = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
1.0865, 1.1529, 1.0507])
self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label)
else:
print("w/o class balance")
self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)
def forward(self, predict, target, weight=None):
class LabelSmoothing(nn.Module):
"""
NLL loss with label smoothing.
"""
def __init__(self, smoothing=0.1):
"""
Args:
predict:(n, c, h, w)
target:(n, h, w)
weight (Tensor, optional): a manual rescaling weight given to each class.
If given, has to be a Tensor of size "nclasses"
Constructor for the LabelSmoothing module.
:param smoothing: label smoothing factor
"""
assert not target.requires_grad
assert predict.dim() == 4
assert target.dim() == 3
assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2))
super(LabelSmoothing, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
n, c, h, w = predict.size()
input_label = target.data.cpu().numpy().ravel().astype(np.int32)
x = np.rollaxis(predict.data.cpu().numpy(), 1).reshape((c, -1))
input_prob = np.exp(x - x.max(axis=0).reshape((1, -1)))
input_prob /= input_prob.sum(axis=0).reshape((1, -1))
def forward(self, x, target):
logprobs = torch.nn.functional.log_softmax(x, dim=-1)
valid_flag = input_label != self.ignore_label
valid_inds = np.where(valid_flag)[0]
label = input_label[valid_flag]
num_valid = valid_flag.sum()
if self.min_kept >= num_valid:
print('Labels: {}'.format(num_valid))
elif num_valid > 0:
prob = input_prob[:,valid_flag]
pred = prob[label, np.arange(len(label), dtype=np.int32)]
threshold = self.thresh
if self.min_kept > 0:
index = pred.argsort()
threshold_index = index[ min(len(index), self.min_kept) - 1 ]
if pred[threshold_index] > self.thresh:
threshold = pred[threshold_index]
kept_flag = pred <= threshold
valid_inds = valid_inds[kept_flag]
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)
smooth_loss = -logprobs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
label = input_label[valid_inds].copy()
input_label.fill(self.ignore_label)
input_label[valid_inds] = label
valid_flag_new = input_label != self.ignore_label
# print(np.sum(valid_flag_new))
target = Variable(torch.from_numpy(input_label.reshape(target.size())).long().cuda())
class NLLMultiLabelSmooth(nn.Module):
def __init__(self, smoothing = 0.1):
super(NLLMultiLabelSmooth, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
return self.criterion(predict, target)
def forward(self, x, target):
if self.training:
x = x.float()
target = target.float()
logprobs = torch.nn.functional.log_softmax(x, dim = -1)
nll_loss = -logprobs * target
nll_loss = nll_loss.sum(-1)
smooth_loss = -logprobs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
else:
return torch.nn.functional.cross_entropy(x, target)
class OHEMSegmentationLosses(OhemCrossEntropy2d):
class SegmentationLosses(nn.CrossEntropyLoss):
"""2D Cross Entropy Loss with Auxilary Loss"""
def __init__(self, se_loss=False, se_weight=0.2, nclass=-1,
aux=False, aux_weight=0.4, weight=None,
ignore_index=-1):
super(OHEMSegmentationLosses, self).__init__(ignore_index)
super(SegmentationLosses, self).__init__(weight, None, ignore_index)
self.se_loss = se_loss
self.aux = aux
self.nclass = nclass
@@ -133,23 +65,23 @@ class OHEMSegmentationLosses(OhemCrossEntropy2d):
def forward(self, *inputs):
if not self.se_loss and not self.aux:
return super(OHEMSegmentationLosses, self).forward(*inputs)
return super(SegmentationLosses, self).forward(*inputs)
elif not self.se_loss:
pred1, pred2, target = tuple(inputs)
loss1 = super(OHEMSegmentationLosses, self).forward(pred1, target)
loss2 = super(OHEMSegmentationLosses, self).forward(pred2, target)
loss1 = super(SegmentationLosses, self).forward(pred1, target)
loss2 = super(SegmentationLosses, self).forward(pred2, target)
return loss1 + self.aux_weight * loss2
elif not self.aux:
pred, se_pred, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred)
loss1 = super(OHEMSegmentationLosses, self).forward(pred, target)
loss1 = super(SegmentationLosses, self).forward(pred, target)
loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.se_weight * loss2
else:
pred1, se_pred, pred2, target = tuple(inputs)
se_target = self._get_batch_label_vector(target, nclass=self.nclass).type_as(pred1)
loss1 = super(OHEMSegmentationLosses, self).forward(pred1, target)
loss2 = super(OHEMSegmentationLosses, self).forward(pred2, target)
loss1 = super(SegmentationLosses, self).forward(pred1, target)
loss2 = super(SegmentationLosses, self).forward(pred2, target)
loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
return loss1 + self.aux_weight * loss2 + self.se_weight * loss3
......
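A minimal sanity check for the smoothing losses added in this commit (my example; values are illustrative, and NLLMultiLabelSmooth expects a one-hot or mixup-style soft target in training mode):

import torch

crit = LabelSmoothing(smoothing=0.1)
logits = torch.randn(8, 1000)               # (batch, num_classes)
target = torch.randint(0, 1000, (8,))
print(crit(logits, target))                 # scalar smoothed NLL

crit2 = NLLMultiLabelSmooth(smoothing=0.1)
crit2.train()
soft_target = torch.zeros(8, 1000).scatter_(1, target.unsqueeze(1), 1.0)
print(crit2(logits, soft_target))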
@@ -9,14 +9,10 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Encoding Util Tools"""
from .lr_scheduler import LR_Scheduler
from .metrics import SegmentationMetric, batch_intersection_union, batch_pix_accuracy
from .lr_scheduler import *
from .metrics import *
from .pallete import get_mask_pallete
from .train_helper import *
from .presets import load_image
from .files import *
from .misc import *
__all__ = ['LR_Scheduler', 'batch_pix_accuracy', 'batch_intersection_union',
'save_checkpoint', 'download', 'mkdir', 'check_sha1', 'load_image',
'get_mask_pallete', 'get_selabel_vector', 'EMA']
@@ -10,7 +10,10 @@ __all__ = ['save_checkpoint', 'download', 'mkdir', 'check_sha1']
def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'):
"""Saves checkpoint to disk"""
directory = "runs/%s/%s/%s/"%(args.dataset, args.model, args.checkname)
if hasattr(args, 'backbone'):
directory = "runs/%s/%s/%s/%s/"%(args.dataset, args.model, args.backbone, args.checkname)
else:
directory = "runs/%s/%s/%s/"%(args.dataset, args.model, args.checkname)
if not os.path.exists(directory):
os.makedirs(directory)
filename = directory + filename
......
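For clarity, the directory layout the updated save_checkpoint produces (a sketch; the argument values are made up):

from types import SimpleNamespace

args = SimpleNamespace(dataset='ade20k', model='deeplab', backbone='resnest50', checkname='default')
# args has a `backbone` attribute, so checkpoints go to:
#   runs/ade20k/deeplab/resnest50/default/checkpoint.pth.tar

args = SimpleNamespace(dataset='imagenet', model='resnest50', checkname='default')
# no `backbone` attribute, so the old layout is kept:
#   runs/imagenet/resnest50/default/checkpoint.pth.tar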