"experiments/guests/gem5-pair-client-udp-0m/run.sh" did not exist on "42fdc0aa231e992635e73342cf217dee62bcd785"
Commit c04f261a authored by dongchy920's avatar dongchy920
Browse files

InstruceBLIP

parents
Pipeline #1594 canceled with stages
import torch
from ..builder import HEADS
from .fcn_head import FCNHead
try:
from annotator.uniformer.mmcv.ops import CrissCrossAttention
except ModuleNotFoundError:
CrissCrossAttention = None
@HEADS.register_module()
class CCHead(FCNHead):
"""CCNet: Criss-Cross Attention for Semantic Segmentation.
This head is the implementation of `CCNet
<https://arxiv.org/abs/1811.11721>`_.
Args:
recurrence (int): Number of recurrence of Criss Cross Attention
module. Default: 2.
"""
def __init__(self, recurrence=2, **kwargs):
if CrissCrossAttention is None:
raise RuntimeError('Please install mmcv-full for '
'CrissCrossAttention ops')
super(CCHead, self).__init__(num_convs=2, **kwargs)
self.recurrence = recurrence
self.cca = CrissCrossAttention(self.channels)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
output = self.convs[0](x)
for _ in range(self.recurrence):
output = self.cca(output)
output = self.convs[1](output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output
import torch
import torch.nn.functional as F
from annotator.uniformer.mmcv.cnn import ConvModule, Scale
from torch import nn
from annotator.uniformer.mmseg.core import add_prefix
from ..builder import HEADS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead
class PAM(_SelfAttentionBlock):
"""Position Attention Module (PAM)
Args:
in_channels (int): Input channels of key/query feature.
channels (int): Output channels of key/query transform.
"""
def __init__(self, in_channels, channels):
super(PAM, self).__init__(
key_in_channels=in_channels,
query_in_channels=in_channels,
channels=channels,
out_channels=in_channels,
share_key_query=False,
query_downsample=None,
key_downsample=None,
key_query_num_convs=1,
key_query_norm=False,
value_out_num_convs=1,
value_out_norm=False,
matmul_norm=False,
with_out=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=None)
self.gamma = Scale(0)
def forward(self, x):
"""Forward function."""
out = super(PAM, self).forward(x, x)
out = self.gamma(out) + x
return out
class CAM(nn.Module):
"""Channel Attention Module (CAM)"""
def __init__(self):
super(CAM, self).__init__()
self.gamma = Scale(0)
def forward(self, x):
"""Forward function."""
batch_size, channels, height, width = x.size()
proj_query = x.view(batch_size, channels, -1)
proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
energy = torch.bmm(proj_query, proj_key)
energy_new = torch.max(
energy, -1, keepdim=True)[0].expand_as(energy) - energy
attention = F.softmax(energy_new, dim=-1)
proj_value = x.view(batch_size, channels, -1)
out = torch.bmm(attention, proj_value)
out = out.view(batch_size, channels, height, width)
out = self.gamma(out) + x
return out
@HEADS.register_module()
class DAHead(BaseDecodeHead):
"""Dual Attention Network for Scene Segmentation.
This head is the implementation of `DANet
<https://arxiv.org/abs/1809.02983>`_.
Args:
pam_channels (int): The channels of Position Attention Module(PAM).
"""
def __init__(self, pam_channels, **kwargs):
super(DAHead, self).__init__(**kwargs)
self.pam_channels = pam_channels
self.pam_in_conv = ConvModule(
self.in_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.pam = PAM(self.channels, pam_channels)
self.pam_out_conv = ConvModule(
self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.pam_conv_seg = nn.Conv2d(
self.channels, self.num_classes, kernel_size=1)
self.cam_in_conv = ConvModule(
self.in_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.cam = CAM()
self.cam_out_conv = ConvModule(
self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.cam_conv_seg = nn.Conv2d(
self.channels, self.num_classes, kernel_size=1)
def pam_cls_seg(self, feat):
"""PAM feature classification."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.pam_conv_seg(feat)
return output
def cam_cls_seg(self, feat):
"""CAM feature classification."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.cam_conv_seg(feat)
return output
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
pam_feat = self.pam_in_conv(x)
pam_feat = self.pam(pam_feat)
pam_feat = self.pam_out_conv(pam_feat)
pam_out = self.pam_cls_seg(pam_feat)
cam_feat = self.cam_in_conv(x)
cam_feat = self.cam(cam_feat)
cam_feat = self.cam_out_conv(cam_feat)
cam_out = self.cam_cls_seg(cam_feat)
feat_sum = pam_feat + cam_feat
pam_cam_out = self.cls_seg(feat_sum)
return pam_cam_out, pam_out, cam_out
def forward_test(self, inputs, img_metas, test_cfg):
"""Forward function for testing, only ``pam_cam`` is used."""
return self.forward(inputs)[0]
def losses(self, seg_logit, seg_label):
"""Compute ``pam_cam``, ``pam``, ``cam`` loss."""
pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
loss = dict()
loss.update(
add_prefix(
super(DAHead, self).losses(pam_cam_seg_logit, seg_label),
'pam_cam'))
loss.update(
add_prefix(
super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam'))
loss.update(
add_prefix(
super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam'))
return loss
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import normal_init
from annotator.uniformer.mmcv.runner import auto_fp16, force_fp32
from annotator.uniformer.mmseg.core import build_pixel_sampler
from annotator.uniformer.mmseg.ops import resize
from ..builder import build_loss
from ..losses import accuracy
class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
Args:
in_channels (int|Sequence[int]): Input channels.
channels (int): Channels after modules, before conv_seg.
num_classes (int): Number of classes.
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
conv_cfg (dict|None): Config of conv layers. Default: None.
norm_cfg (dict|None): Config of norm layers. Default: None.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU')
in_index (int|Sequence[int]): Input feature index. Default: -1
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
Default: None.
loss_decode (dict): Config of decode loss.
Default: dict(type='CrossEntropyLoss').
ignore_index (int | None): The label index to be ignored. When using
masked BCE loss, ignore_index should be set to None. Default: 255
sampler (dict|None): The config of segmentation map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
"""
def __init__(self,
in_channels,
channels,
*,
num_classes,
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
in_index=-1,
input_transform=None,
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
ignore_index=255,
sampler=None,
align_corners=False):
super(BaseDecodeHead, self).__init__()
self._init_inputs(in_channels, in_index, input_transform)
self.channels = channels
self.num_classes = num_classes
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.in_index = in_index
self.loss_decode = build_loss(loss_decode)
self.ignore_index = ignore_index
self.align_corners = align_corners
if sampler is not None:
self.sampler = build_pixel_sampler(sampler, context=self)
else:
self.sampler = None
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
if dropout_ratio > 0:
self.dropout = nn.Dropout2d(dropout_ratio)
else:
self.dropout = None
self.fp16_enabled = False
def extra_repr(self):
"""Extra repr."""
s = f'input_transform={self.input_transform}, ' \
f'ignore_index={self.ignore_index}, ' \
f'align_corners={self.align_corners}'
return s
def _init_inputs(self, in_channels, in_index, input_transform):
"""Check and initialize input transforms.
The in_channels, in_index and input_transform must match.
Specifically, when input_transform is None, only single feature map
will be selected. So in_channels and in_index must be of type int.
When input_transform
Args:
in_channels (int|Sequence[int]): Input channels.
in_index (int|Sequence[int]): Input feature index.
input_transform (str|None): Transformation type of input features.
Options: 'resize_concat', 'multiple_select', None.
'resize_concat': Multiple feature maps will be resize to the
same size as first one and than concat together.
Usually used in FCN head of HRNet.
'multiple_select': Multiple feature maps will be bundle into
a list and passed into decode head.
None: Only one select feature map is allowed.
"""
if input_transform is not None:
assert input_transform in ['resize_concat', 'multiple_select']
self.input_transform = input_transform
self.in_index = in_index
if input_transform is not None:
assert isinstance(in_channels, (list, tuple))
assert isinstance(in_index, (list, tuple))
assert len(in_channels) == len(in_index)
if input_transform == 'resize_concat':
self.in_channels = sum(in_channels)
else:
self.in_channels = in_channels
else:
assert isinstance(in_channels, int)
assert isinstance(in_index, int)
self.in_channels = in_channels
def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.conv_seg, mean=0, std=0.01)
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == 'resize_concat':
inputs = [inputs[i] for i in self.in_index]
upsampled_inputs = [
resize(
input=x,
size=inputs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners) for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == 'multiple_select':
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
@auto_fp16()
@abstractmethod
def forward(self, inputs):
"""Placeholder of forward function."""
pass
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
"""Forward function for training.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
gt_semantic_seg (Tensor): Semantic segmentation masks
used if the architecture supports semantic segmentation task.
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits = self.forward(inputs)
losses = self.losses(seg_logits, gt_semantic_seg)
return losses
def forward_test(self, inputs, img_metas, test_cfg):
"""Forward function for testing.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
test_cfg (dict): The testing config.
Returns:
Tensor: Output segmentation map.
"""
return self.forward(inputs)
def cls_seg(self, feat):
"""Classify each pixel."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.conv_seg(feat)
return output
@force_fp32(apply_to=('seg_logit', ))
def losses(self, seg_logit, seg_label):
"""Compute segmentation loss."""
loss = dict()
seg_logit = resize(
input=seg_logit,
size=seg_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
if self.sampler is not None:
seg_weight = self.sampler.sample(seg_logit, seg_label)
else:
seg_weight = None
seg_label = seg_label.squeeze(1)
loss['loss_seg'] = self.loss_decode(
seg_logit,
seg_label,
weight=seg_weight,
ignore_index=self.ignore_index)
loss['acc_seg'] = accuracy(seg_logit, seg_label)
return loss
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.uniformer.mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from ..builder import HEADS
from .decode_head import BaseDecodeHead
class DCM(nn.Module):
"""Dynamic Convolutional Module used in DMNet.
Args:
filter_size (int): The filter size of generated convolution kernel
used in Dynamic Convolutional Module.
fusion (bool): Add one conv to fuse DCM output feature.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict | None): Config of conv layers.
norm_cfg (dict | None): Config of norm layers.
act_cfg (dict): Config of activation layers.
"""
def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
norm_cfg, act_cfg):
super(DCM, self).__init__()
self.filter_size = filter_size
self.fusion = fusion
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
0)
self.input_redu_conv = ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.norm_cfg is not None:
self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
else:
self.norm = None
self.activate = build_activation_layer(self.act_cfg)
if self.fusion:
self.fusion_conv = ConvModule(
self.channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, x):
"""Forward function."""
generated_filter = self.filter_gen_conv(
F.adaptive_avg_pool2d(x, self.filter_size))
x = self.input_redu_conv(x)
b, c, h, w = x.shape
# [1, b * c, h, w], c = self.channels
x = x.view(1, b * c, h, w)
# [b * c, 1, filter_size, filter_size]
generated_filter = generated_filter.view(b * c, 1, self.filter_size,
self.filter_size)
pad = (self.filter_size - 1) // 2
if (self.filter_size - 1) % 2 == 0:
p2d = (pad, pad, pad, pad)
else:
p2d = (pad + 1, pad, pad + 1, pad)
x = F.pad(input=x, pad=p2d, mode='constant', value=0)
# [1, b * c, h, w]
output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
# [b, c, h, w]
output = output.view(b, c, h, w)
if self.norm is not None:
output = self.norm(output)
output = self.activate(output)
if self.fusion:
output = self.fusion_conv(output)
return output
@HEADS.register_module()
class DMHead(BaseDecodeHead):
"""Dynamic Multi-scale Filters for Semantic Segmentation.
This head is the implementation of
`DMNet <https://openaccess.thecvf.com/content_ICCV_2019/papers/\
He_Dynamic_Multi-Scale_Filters_for_Semantic_Segmentation_\
ICCV_2019_paper.pdf>`_.
Args:
filter_sizes (tuple[int]): The size of generated convolutional filters
used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
fusion (bool): Add one conv to fuse DCM output feature.
"""
def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
super(DMHead, self).__init__(**kwargs)
assert isinstance(filter_sizes, (list, tuple))
self.filter_sizes = filter_sizes
self.fusion = fusion
dcm_modules = []
for filter_size in self.filter_sizes:
dcm_modules.append(
DCM(filter_size,
self.fusion,
self.in_channels,
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.dcm_modules = nn.ModuleList(dcm_modules)
self.bottleneck = ConvModule(
self.in_channels + len(filter_sizes) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
dcm_outs = [x]
for dcm_module in self.dcm_modules:
dcm_outs.append(dcm_module(x))
dcm_outs = torch.cat(dcm_outs, dim=1)
output = self.bottleneck(dcm_outs)
output = self.cls_seg(output)
return output
import torch
from annotator.uniformer.mmcv.cnn import NonLocal2d
from torch import nn
from ..builder import HEADS
from .fcn_head import FCNHead
class DisentangledNonLocal2d(NonLocal2d):
"""Disentangled Non-Local Blocks.
Args:
temperature (float): Temperature to adjust attention. Default: 0.05
"""
def __init__(self, *arg, temperature, **kwargs):
super().__init__(*arg, **kwargs)
self.temperature = temperature
self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
def embedded_gaussian(self, theta_x, phi_x):
"""Embedded gaussian with temperature."""
# NonLocal2d pairwise_weight: [N, HxW, HxW]
pairwise_weight = torch.matmul(theta_x, phi_x)
if self.use_scale:
# theta_x.shape[-1] is `self.inter_channels`
pairwise_weight /= theta_x.shape[-1]**0.5
pairwise_weight /= self.temperature
pairwise_weight = pairwise_weight.softmax(dim=-1)
return pairwise_weight
def forward(self, x):
# x: [N, C, H, W]
n = x.size(0)
# g_x: [N, HxW, C]
g_x = self.g(x).view(n, self.inter_channels, -1)
g_x = g_x.permute(0, 2, 1)
# theta_x: [N, HxW, C], phi_x: [N, C, HxW]
if self.mode == 'gaussian':
theta_x = x.view(n, self.in_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
if self.sub_sample:
phi_x = self.phi(x).view(n, self.in_channels, -1)
else:
phi_x = x.view(n, self.in_channels, -1)
elif self.mode == 'concatenation':
theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
else:
theta_x = self.theta(x).view(n, self.inter_channels, -1)
theta_x = theta_x.permute(0, 2, 1)
phi_x = self.phi(x).view(n, self.inter_channels, -1)
# subtract mean
theta_x -= theta_x.mean(dim=-2, keepdim=True)
phi_x -= phi_x.mean(dim=-1, keepdim=True)
pairwise_func = getattr(self, self.mode)
# pairwise_weight: [N, HxW, HxW]
pairwise_weight = pairwise_func(theta_x, phi_x)
# y: [N, HxW, C]
y = torch.matmul(pairwise_weight, g_x)
# y: [N, C, H, W]
y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
*x.size()[2:])
# unary_mask: [N, 1, HxW]
unary_mask = self.conv_mask(x)
unary_mask = unary_mask.view(n, 1, -1)
unary_mask = unary_mask.softmax(dim=-1)
# unary_x: [N, 1, C]
unary_x = torch.matmul(unary_mask, g_x)
# unary_x: [N, C, 1, 1]
unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
n, self.inter_channels, 1, 1)
output = x + self.conv_out(y + unary_x)
return output
@HEADS.register_module()
class DNLHead(FCNHead):
"""Disentangled Non-Local Neural Networks.
This head is the implementation of `DNLNet
<https://arxiv.org/abs/2006.06668>`_.
Args:
reduction (int): Reduction factor of projection transform. Default: 2.
use_scale (bool): Whether to scale pairwise_weight by
sqrt(1/inter_channels). Default: False.
mode (str): The nonlocal mode. Options are 'embedded_gaussian',
'dot_product'. Default: 'embedded_gaussian.'.
temperature (float): Temperature to adjust attention. Default: 0.05
"""
def __init__(self,
reduction=2,
use_scale=True,
mode='embedded_gaussian',
temperature=0.05,
**kwargs):
super(DNLHead, self).__init__(num_convs=2, **kwargs)
self.reduction = reduction
self.use_scale = use_scale
self.mode = mode
self.temperature = temperature
self.dnl_block = DisentangledNonLocal2d(
in_channels=self.channels,
reduction=self.reduction,
use_scale=self.use_scale,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
mode=self.mode,
temperature=self.temperature)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
output = self.convs[0](x)
output = self.dnl_block(output)
output = self.convs[1](output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output
import math
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from annotator.uniformer.mmcv.cnn import ConvModule
from ..builder import HEADS
from .decode_head import BaseDecodeHead
def reduce_mean(tensor):
"""Reduce mean when distributed training."""
if not (dist.is_available() and dist.is_initialized()):
return tensor
tensor = tensor.clone()
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
return tensor
class EMAModule(nn.Module):
"""Expectation Maximization Attention Module used in EMANet.
Args:
channels (int): Channels of the whole module.
num_bases (int): Number of bases.
num_stages (int): Number of the EM iterations.
"""
def __init__(self, channels, num_bases, num_stages, momentum):
super(EMAModule, self).__init__()
assert num_stages >= 1, 'num_stages must be at least 1!'
self.num_bases = num_bases
self.num_stages = num_stages
self.momentum = momentum
bases = torch.zeros(1, channels, self.num_bases)
bases.normal_(0, math.sqrt(2. / self.num_bases))
# [1, channels, num_bases]
bases = F.normalize(bases, dim=1, p=2)
self.register_buffer('bases', bases)
def forward(self, feats):
"""Forward function."""
batch_size, channels, height, width = feats.size()
# [batch_size, channels, height*width]
feats = feats.view(batch_size, channels, height * width)
# [batch_size, channels, num_bases]
bases = self.bases.repeat(batch_size, 1, 1)
with torch.no_grad():
for i in range(self.num_stages):
# [batch_size, height*width, num_bases]
attention = torch.einsum('bcn,bck->bnk', feats, bases)
attention = F.softmax(attention, dim=2)
# l1 norm
attention_normed = F.normalize(attention, dim=1, p=1)
# [batch_size, channels, num_bases]
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
# l2 norm
bases = F.normalize(bases, dim=1, p=2)
feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
feats_recon = feats_recon.view(batch_size, channels, height, width)
if self.training:
bases = bases.mean(dim=0, keepdim=True)
bases = reduce_mean(bases)
# l2 norm
bases = F.normalize(bases, dim=1, p=2)
self.bases = (1 -
self.momentum) * self.bases + self.momentum * bases
return feats_recon
@HEADS.register_module()
class EMAHead(BaseDecodeHead):
"""Expectation Maximization Attention Networks for Semantic Segmentation.
This head is the implementation of `EMANet
<https://arxiv.org/abs/1907.13426>`_.
Args:
ema_channels (int): EMA module channels
num_bases (int): Number of bases.
num_stages (int): Number of the EM iterations.
concat_input (bool): Whether concat the input and output of convs
before classification layer. Default: True
momentum (float): Momentum to update the base. Default: 0.1.
"""
def __init__(self,
ema_channels,
num_bases,
num_stages,
concat_input=True,
momentum=0.1,
**kwargs):
super(EMAHead, self).__init__(**kwargs)
self.ema_channels = ema_channels
self.num_bases = num_bases
self.num_stages = num_stages
self.concat_input = concat_input
self.momentum = momentum
self.ema_module = EMAModule(self.ema_channels, self.num_bases,
self.num_stages, self.momentum)
self.ema_in_conv = ConvModule(
self.in_channels,
self.ema_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
# project (0, inf) -> (-inf, inf)
self.ema_mid_conv = ConvModule(
self.ema_channels,
self.ema_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=None,
act_cfg=None)
for param in self.ema_mid_conv.parameters():
param.requires_grad = False
self.ema_out_conv = ConvModule(
self.ema_channels,
self.ema_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=None)
self.bottleneck = ConvModule(
self.ema_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.concat_input:
self.conv_cat = ConvModule(
self.in_channels + self.channels,
self.channels,
kernel_size=3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
feats = self.ema_in_conv(x)
identity = feats
feats = self.ema_mid_conv(feats)
recon = self.ema_module(feats)
recon = F.relu(recon, inplace=True)
recon = self.ema_out_conv(recon)
output = F.relu(identity + recon, inplace=True)
output = self.bottleneck(output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.uniformer.mmcv.cnn import ConvModule, build_norm_layer
from annotator.uniformer.mmseg.ops import Encoding, resize
from ..builder import HEADS, build_loss
from .decode_head import BaseDecodeHead
class EncModule(nn.Module):
"""Encoding Module used in EncNet.
Args:
in_channels (int): Input channels.
num_codes (int): Number of code words.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
"""
def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
super(EncModule, self).__init__()
self.encoding_project = ConvModule(
in_channels,
in_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
# TODO: resolve this hack
# change to 1d
if norm_cfg is not None:
encoding_norm_cfg = norm_cfg.copy()
if encoding_norm_cfg['type'] in ['BN', 'IN']:
encoding_norm_cfg['type'] += '1d'
else:
encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
'2d', '1d')
else:
# fallback to BN1d
encoding_norm_cfg = dict(type='BN1d')
self.encoding = nn.Sequential(
Encoding(channels=in_channels, num_codes=num_codes),
build_norm_layer(encoding_norm_cfg, num_codes)[1],
nn.ReLU(inplace=True))
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels), nn.Sigmoid())
def forward(self, x):
"""Forward function."""
encoding_projection = self.encoding_project(x)
encoding_feat = self.encoding(encoding_projection).mean(dim=1)
batch_size, channels, _, _ = x.size()
gamma = self.fc(encoding_feat)
y = gamma.view(batch_size, channels, 1, 1)
output = F.relu_(x + x * y)
return encoding_feat, output
@HEADS.register_module()
class EncHead(BaseDecodeHead):
"""Context Encoding for Semantic Segmentation.
This head is the implementation of `EncNet
<https://arxiv.org/abs/1803.08904>`_.
Args:
num_codes (int): Number of code words. Default: 32.
use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to
regularize the training. Default: True.
add_lateral (bool): Whether use lateral connection to fuse features.
Default: False.
loss_se_decode (dict): Config of decode loss.
Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
"""
def __init__(self,
num_codes=32,
use_se_loss=True,
add_lateral=False,
loss_se_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=0.2),
**kwargs):
super(EncHead, self).__init__(
input_transform='multiple_select', **kwargs)
self.use_se_loss = use_se_loss
self.add_lateral = add_lateral
self.num_codes = num_codes
self.bottleneck = ConvModule(
self.in_channels[-1],
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if add_lateral:
self.lateral_convs = nn.ModuleList()
for in_channels in self.in_channels[:-1]: # skip the last one
self.lateral_convs.append(
ConvModule(
in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.fusion = ConvModule(
len(self.in_channels) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.enc_module = EncModule(
self.channels,
num_codes=num_codes,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.use_se_loss:
self.loss_se_decode = build_loss(loss_se_decode)
self.se_layer = nn.Linear(self.channels, self.num_classes)
def forward(self, inputs):
"""Forward function."""
inputs = self._transform_inputs(inputs)
feat = self.bottleneck(inputs[-1])
if self.add_lateral:
laterals = [
resize(
lateral_conv(inputs[i]),
size=feat.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
for i, lateral_conv in enumerate(self.lateral_convs)
]
feat = self.fusion(torch.cat([feat, *laterals], 1))
encode_feat, output = self.enc_module(feat)
output = self.cls_seg(output)
if self.use_se_loss:
se_output = self.se_layer(encode_feat)
return output, se_output
else:
return output
def forward_test(self, inputs, img_metas, test_cfg):
"""Forward function for testing, ignore se_loss."""
if self.use_se_loss:
return self.forward(inputs)[0]
else:
return self.forward(inputs)
@staticmethod
def _convert_to_onehot_labels(seg_label, num_classes):
"""Convert segmentation label to onehot.
Args:
seg_label (Tensor): Segmentation label of shape (N, H, W).
num_classes (int): Number of classes.
Returns:
Tensor: Onehot labels of shape (N, num_classes).
"""
batch_size = seg_label.size(0)
onehot_labels = seg_label.new_zeros((batch_size, num_classes))
for i in range(batch_size):
hist = seg_label[i].float().histc(
bins=num_classes, min=0, max=num_classes - 1)
onehot_labels[i] = hist > 0
return onehot_labels
def losses(self, seg_logit, seg_label):
"""Compute segmentation and semantic encoding loss."""
seg_logit, se_seg_logit = seg_logit
loss = dict()
loss.update(super(EncHead, self).losses(seg_logit, seg_label))
se_loss = self.loss_se_decode(
se_seg_logit,
self._convert_to_onehot_labels(seg_label, self.num_classes))
loss['loss_se'] = se_loss
return loss
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
class FCNHead(BaseDecodeHead):
"""Fully Convolution Networks for Semantic Segmentation.
This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
Args:
num_convs (int): Number of convs in the head. Default: 2.
kernel_size (int): The kernel size for convs in the head. Default: 3.
concat_input (bool): Whether concat the input and output of convs
before classification layer.
dilation (int): The dilation rate for convs in the head. Default: 1.
"""
def __init__(self,
num_convs=2,
kernel_size=3,
concat_input=True,
dilation=1,
**kwargs):
assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
self.num_convs = num_convs
self.concat_input = concat_input
self.kernel_size = kernel_size
super(FCNHead, self).__init__(**kwargs)
if num_convs == 0:
assert self.in_channels == self.channels
conv_padding = (kernel_size // 2) * dilation
convs = []
convs.append(
ConvModule(
self.in_channels,
self.channels,
kernel_size=kernel_size,
padding=conv_padding,
dilation=dilation,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
for i in range(num_convs - 1):
convs.append(
ConvModule(
self.channels,
self.channels,
kernel_size=kernel_size,
padding=conv_padding,
dilation=dilation,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
if num_convs == 0:
self.convs = nn.Identity()
else:
self.convs = nn.Sequential(*convs)
if self.concat_input:
self.conv_cat = ConvModule(
self.in_channels + self.channels,
self.channels,
kernel_size=kernel_size,
padding=kernel_size // 2,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
output = self.convs(x)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output
import numpy as np
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule
from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
class FPNHead(BaseDecodeHead):
"""Panoptic Feature Pyramid Networks.
This head is the implementation of `Semantic FPN
<https://arxiv.org/abs/1901.02446>`_.
Args:
feature_strides (tuple[int]): The strides for input feature maps.
stack_lateral. All strides suppose to be power of 2. The first
one is of largest resolution.
"""
def __init__(self, feature_strides, **kwargs):
super(FPNHead, self).__init__(
input_transform='multiple_select', **kwargs)
assert len(feature_strides) == len(self.in_channels)
assert min(feature_strides) == feature_strides[0]
self.feature_strides = feature_strides
self.scale_heads = nn.ModuleList()
for i in range(len(feature_strides)):
head_length = max(
1,
int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
scale_head = []
for k in range(head_length):
scale_head.append(
ConvModule(
self.in_channels[i] if k == 0 else self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
if feature_strides[i] != feature_strides[0]:
scale_head.append(
nn.Upsample(
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners))
self.scale_heads.append(nn.Sequential(*scale_head))
def forward(self, inputs):
x = self._transform_inputs(inputs)
output = self.scale_heads[0](x[0])
for i in range(1, len(self.feature_strides)):
# non inplace
output = output + resize(
self.scale_heads[i](x[i]),
size=output.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
output = self.cls_seg(output)
return output
import torch
from annotator.uniformer.mmcv.cnn import ContextBlock
from ..builder import HEADS
from .fcn_head import FCNHead
@HEADS.register_module()
class GCHead(FCNHead):
"""GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.
This head is the implementation of `GCNet
<https://arxiv.org/abs/1904.11492>`_.
Args:
ratio (float): Multiplier of channels ratio. Default: 1/4.
pooling_type (str): The pooling type of context aggregation.
Options are 'att', 'avg'. Default: 'avg'.
fusion_types (tuple[str]): The fusion type for feature fusion.
Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
"""
def __init__(self,
ratio=1 / 4.,
pooling_type='att',
fusion_types=('channel_add', ),
**kwargs):
super(GCHead, self).__init__(num_convs=2, **kwargs)
self.ratio = ratio
self.pooling_type = pooling_type
self.fusion_types = fusion_types
self.gc_block = ContextBlock(
in_channels=self.channels,
ratio=self.ratio,
pooling_type=self.pooling_type,
fusion_types=self.fusion_types)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
output = self.convs[0](x)
output = self.gc_block(output)
output = self.convs[1](output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output
import torch
import torch.nn as nn
from annotator.uniformer.mmcv import is_tuple_of
from annotator.uniformer.mmcv.cnn import ConvModule
from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
class LRASPPHead(BaseDecodeHead):
"""Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.
This head is the improved implementation of `Searching for MobileNetV3
<https://ieeexplore.ieee.org/document/9008835>`_.
Args:
branch_channels (tuple[int]): The number of output channels in every
each branch. Default: (32, 64).
"""
def __init__(self, branch_channels=(32, 64), **kwargs):
super(LRASPPHead, self).__init__(**kwargs)
if self.input_transform != 'multiple_select':
raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
f'must be \'multiple_select\'. But received '
f'\'{self.input_transform}\'')
assert is_tuple_of(branch_channels, int)
assert len(branch_channels) == len(self.in_channels) - 1
self.branch_channels = branch_channels
self.convs = nn.Sequential()
self.conv_ups = nn.Sequential()
for i in range(len(branch_channels)):
self.convs.add_module(
f'conv{i}',
nn.Conv2d(
self.in_channels[i], branch_channels[i], 1, bias=False))
self.conv_ups.add_module(
f'conv_up{i}',
ConvModule(
self.channels + branch_channels[i],
self.channels,
1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
bias=False))
self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)
self.aspp_conv = ConvModule(
self.in_channels[-1],
self.channels,
1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
bias=False)
self.image_pool = nn.Sequential(
nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
ConvModule(
self.in_channels[2],
self.channels,
1,
act_cfg=dict(type='Sigmoid'),
bias=False))
def forward(self, inputs):
"""Forward function."""
inputs = self._transform_inputs(inputs)
x = inputs[-1]
x = self.aspp_conv(x) * resize(
self.image_pool(x),
size=x.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
x = self.conv_up_input(x)
for i in range(len(self.branch_channels) - 1, -1, -1):
x = resize(
x,
size=inputs[i].size()[2:],
mode='bilinear',
align_corners=self.align_corners)
x = torch.cat([x, self.convs[i](inputs[i])], 1)
x = self.conv_ups[i](x)
return self.cls_seg(x)
import torch
from annotator.uniformer.mmcv.cnn import NonLocal2d
from ..builder import HEADS
from .fcn_head import FCNHead
@HEADS.register_module()
class NLHead(FCNHead):
"""Non-local Neural Networks.
This head is the implementation of `NLNet
<https://arxiv.org/abs/1711.07971>`_.
Args:
reduction (int): Reduction factor of projection transform. Default: 2.
use_scale (bool): Whether to scale pairwise_weight by
sqrt(1/inter_channels). Default: True.
mode (str): The nonlocal mode. Options are 'embedded_gaussian',
'dot_product'. Default: 'embedded_gaussian.'.
"""
def __init__(self,
reduction=2,
use_scale=True,
mode='embedded_gaussian',
**kwargs):
super(NLHead, self).__init__(num_convs=2, **kwargs)
self.reduction = reduction
self.use_scale = use_scale
self.mode = mode
self.nl_block = NonLocal2d(
in_channels=self.channels,
reduction=self.reduction,
use_scale=self.use_scale,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
mode=self.mode)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
output = self.convs[0](x)
output = self.nl_block(output)
output = self.convs[1](output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.uniformer.mmcv.cnn import ConvModule
from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .cascade_decode_head import BaseCascadeDecodeHead
class SpatialGatherModule(nn.Module):
"""Aggregate the context features according to the initial predicted
probability distribution.
Employ the soft-weighted method to aggregate the context.
"""
def __init__(self, scale):
super(SpatialGatherModule, self).__init__()
self.scale = scale
def forward(self, feats, probs):
"""Forward function."""
batch_size, num_classes, height, width = probs.size()
channels = feats.size(1)
probs = probs.view(batch_size, num_classes, -1)
feats = feats.view(batch_size, channels, -1)
# [batch_size, height*width, num_classes]
feats = feats.permute(0, 2, 1)
# [batch_size, channels, height*width]
probs = F.softmax(self.scale * probs, dim=2)
# [batch_size, channels, num_classes]
ocr_context = torch.matmul(probs, feats)
ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
return ocr_context
class ObjectAttentionBlock(_SelfAttentionBlock):
"""Make a OCR used SelfAttentionBlock."""
def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
act_cfg):
if scale > 1:
query_downsample = nn.MaxPool2d(kernel_size=scale)
else:
query_downsample = None
super(ObjectAttentionBlock, self).__init__(
key_in_channels=in_channels,
query_in_channels=in_channels,
channels=channels,
out_channels=in_channels,
share_key_query=False,
query_downsample=query_downsample,
key_downsample=None,
key_query_num_convs=2,
key_query_norm=True,
value_out_num_convs=1,
value_out_norm=True,
matmul_norm=True,
with_out=True,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.bottleneck = ConvModule(
in_channels * 2,
in_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, query_feats, key_feats):
"""Forward function."""
context = super(ObjectAttentionBlock,
self).forward(query_feats, key_feats)
output = self.bottleneck(torch.cat([context, query_feats], dim=1))
if self.query_downsample is not None:
output = resize(query_feats)
return output
@HEADS.register_module()
class OCRHead(BaseCascadeDecodeHead):
"""Object-Contextual Representations for Semantic Segmentation.
This head is the implementation of `OCRNet
<https://arxiv.org/abs/1909.11065>`_.
Args:
ocr_channels (int): The intermediate channels of OCR block.
scale (int): The scale of probability map in SpatialGatherModule in
Default: 1.
"""
def __init__(self, ocr_channels, scale=1, **kwargs):
super(OCRHead, self).__init__(**kwargs)
self.ocr_channels = ocr_channels
self.scale = scale
self.object_context_block = ObjectAttentionBlock(
self.channels,
self.ocr_channels,
self.scale,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.spatial_gather_module = SpatialGatherModule(self.scale)
self.bottleneck = ConvModule(
self.in_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs, prev_output):
"""Forward function."""
x = self._transform_inputs(inputs)
feats = self.bottleneck(x)
context = self.spatial_gather_module(feats, prev_output)
object_context = self.object_context_block(feats, context)
output = self.cls_seg(object_context)
return output
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule, normal_init
from annotator.uniformer.mmcv.ops import point_sample
from annotator.uniformer.mmseg.models.builder import HEADS
from annotator.uniformer.mmseg.ops import resize
from ..losses import accuracy
from .cascade_decode_head import BaseCascadeDecodeHead
def calculate_uncertainty(seg_logits):
"""Estimate uncertainty based on seg logits.
For each location of the prediction ``seg_logits`` we estimate
uncertainty as the difference between top first and top second
predicted logits.
Args:
seg_logits (Tensor): Semantic segmentation logits,
shape (batch_size, num_classes, height, width).
Returns:
scores (Tensor): T uncertainty scores with the most uncertain
locations having the highest uncertainty score, shape (
batch_size, 1, height, width)
"""
top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
@HEADS.register_module()
class PointHead(BaseCascadeDecodeHead):
"""A mask point head use in PointRend.
``PointHead`` use shared multi-layer perceptron (equivalent to
nn.Conv1d) to predict the logit of input points. The fine-grained feature
and coarse feature will be concatenate together for predication.
Args:
num_fcs (int): Number of fc layers in the head. Default: 3.
in_channels (int): Number of input channels. Default: 256.
fc_channels (int): Number of fc channels. Default: 256.
num_classes (int): Number of classes for logits. Default: 80.
class_agnostic (bool): Whether use class agnostic classification.
If so, the output channels of logits will be 1. Default: False.
coarse_pred_each_layer (bool): Whether concatenate coarse feature with
the output of each fc layer. Default: True.
conv_cfg (dict|None): Dictionary to construct and config conv layer.
Default: dict(type='Conv1d'))
norm_cfg (dict|None): Dictionary to construct and config norm layer.
Default: None.
loss_point (dict): Dictionary to construct and config loss layer of
point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
loss_weight=1.0).
"""
def __init__(self,
num_fcs=3,
coarse_pred_each_layer=True,
conv_cfg=dict(type='Conv1d'),
norm_cfg=None,
act_cfg=dict(type='ReLU', inplace=False),
**kwargs):
super(PointHead, self).__init__(
input_transform='multiple_select',
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**kwargs)
self.num_fcs = num_fcs
self.coarse_pred_each_layer = coarse_pred_each_layer
fc_in_channels = sum(self.in_channels) + self.num_classes
fc_channels = self.channels
self.fcs = nn.ModuleList()
for k in range(num_fcs):
fc = ConvModule(
fc_in_channels,
fc_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.fcs.append(fc)
fc_in_channels = fc_channels
fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
else 0
self.fc_seg = nn.Conv1d(
fc_in_channels,
self.num_classes,
kernel_size=1,
stride=1,
padding=0)
if self.dropout_ratio > 0:
self.dropout = nn.Dropout(self.dropout_ratio)
delattr(self, 'conv_seg')
def init_weights(self):
"""Initialize weights of classification layer."""
normal_init(self.fc_seg, std=0.001)
def cls_seg(self, feat):
"""Classify each pixel with fc."""
if self.dropout is not None:
feat = self.dropout(feat)
output = self.fc_seg(feat)
return output
def forward(self, fine_grained_point_feats, coarse_point_feats):
x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
for fc in self.fcs:
x = fc(x)
if self.coarse_pred_each_layer:
x = torch.cat((x, coarse_point_feats), dim=1)
return self.cls_seg(x)
def _get_fine_grained_point_feats(self, x, points):
"""Sample from fine grained features.
Args:
x (list[Tensor]): Feature pyramid from by neck or backbone.
points (Tensor): Point coordinates, shape (batch_size,
num_points, 2).
Returns:
fine_grained_feats (Tensor): Sampled fine grained feature,
shape (batch_size, sum(channels of x), num_points).
"""
fine_grained_feats_list = [
point_sample(_, points, align_corners=self.align_corners)
for _ in x
]
if len(fine_grained_feats_list) > 1:
fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
else:
fine_grained_feats = fine_grained_feats_list[0]
return fine_grained_feats
def _get_coarse_point_feats(self, prev_output, points):
"""Sample from fine grained features.
Args:
prev_output (list[Tensor]): Prediction of previous decode head.
points (Tensor): Point coordinates, shape (batch_size,
num_points, 2).
Returns:
coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
num_classes, num_points).
"""
coarse_feats = point_sample(
prev_output, points, align_corners=self.align_corners)
return coarse_feats
def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
train_cfg):
"""Forward function for training.
Args:
inputs (list[Tensor]): List of multi-level img features.
prev_output (Tensor): The output of previous decode head.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
gt_semantic_seg (Tensor): Semantic segmentation masks
used if the architecture supports semantic segmentation task.
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
x = self._transform_inputs(inputs)
with torch.no_grad():
points = self.get_points_train(
prev_output, calculate_uncertainty, cfg=train_cfg)
fine_grained_point_feats = self._get_fine_grained_point_feats(
x, points)
coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
point_logits = self.forward(fine_grained_point_feats,
coarse_point_feats)
point_label = point_sample(
gt_semantic_seg.float(),
points,
mode='nearest',
align_corners=self.align_corners)
point_label = point_label.squeeze(1).long()
losses = self.losses(point_logits, point_label)
return losses
def forward_test(self, inputs, prev_output, img_metas, test_cfg):
"""Forward function for testing.
Args:
inputs (list[Tensor]): List of multi-level img features.
prev_output (Tensor): The output of previous decode head.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmseg/datasets/pipelines/formatting.py:Collect`.
test_cfg (dict): The testing config.
Returns:
Tensor: Output segmentation map.
"""
x = self._transform_inputs(inputs)
refined_seg_logits = prev_output.clone()
for _ in range(test_cfg.subdivision_steps):
refined_seg_logits = resize(
refined_seg_logits,
scale_factor=test_cfg.scale_factor,
mode='bilinear',
align_corners=self.align_corners)
batch_size, channels, height, width = refined_seg_logits.shape
point_indices, points = self.get_points_test(
refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
fine_grained_point_feats = self._get_fine_grained_point_feats(
x, points)
coarse_point_feats = self._get_coarse_point_feats(
prev_output, points)
point_logits = self.forward(fine_grained_point_feats,
coarse_point_feats)
point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
refined_seg_logits = refined_seg_logits.reshape(
batch_size, channels, height * width)
refined_seg_logits = refined_seg_logits.scatter_(
2, point_indices, point_logits)
refined_seg_logits = refined_seg_logits.view(
batch_size, channels, height, width)
return refined_seg_logits
def losses(self, point_logits, point_label):
"""Compute segmentation loss."""
loss = dict()
loss['loss_point'] = self.loss_decode(
point_logits, point_label, ignore_index=self.ignore_index)
loss['acc_point'] = accuracy(point_logits, point_label)
return loss
def get_points_train(self, seg_logits, uncertainty_func, cfg):
"""Sample points for training.
Sample points in [0, 1] x [0, 1] coordinate space based on their
uncertainty. The uncertainties are calculated for each point using
'uncertainty_func' function that takes point's logit prediction as
input.
Args:
seg_logits (Tensor): Semantic segmentation logits, shape (
batch_size, num_classes, height, width).
uncertainty_func (func): uncertainty calculation function.
cfg (dict): Training config of point head.
Returns:
point_coords (Tensor): A tensor of shape (batch_size, num_points,
2) that contains the coordinates of ``num_points`` sampled
points.
"""
num_points = cfg.num_points
oversample_ratio = cfg.oversample_ratio
importance_sample_ratio = cfg.importance_sample_ratio
assert oversample_ratio >= 1
assert 0 <= importance_sample_ratio <= 1
batch_size = seg_logits.shape[0]
num_sampled = int(num_points * oversample_ratio)
point_coords = torch.rand(
batch_size, num_sampled, 2, device=seg_logits.device)
point_logits = point_sample(seg_logits, point_coords)
# It is crucial to calculate uncertainty based on the sampled
# prediction value for the points. Calculating uncertainties of the
# coarse predictions first and sampling them for points leads to
# incorrect results. To illustrate this: assume uncertainty func(
# logits)=-abs(logits), a sampled point between two coarse
# predictions with -1 and 1 logits has 0 logits, and therefore 0
# uncertainty value. However, if we calculate uncertainties for the
# coarse predictions first, both will have -1 uncertainty,
# and sampled point will get -1 uncertainty.
point_uncertainties = uncertainty_func(point_logits)
num_uncertain_points = int(importance_sample_ratio * num_points)
num_random_points = num_points - num_uncertain_points
idx = torch.topk(
point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
shift = num_sampled * torch.arange(
batch_size, dtype=torch.long, device=seg_logits.device)
idx += shift[:, None]
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
batch_size, num_uncertain_points, 2)
if num_random_points > 0:
rand_point_coords = torch.rand(
batch_size, num_random_points, 2, device=seg_logits.device)
point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
return point_coords
def get_points_test(self, seg_logits, uncertainty_func, cfg):
"""Sample points for testing.
Find ``num_points`` most uncertain points from ``uncertainty_map``.
Args:
seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
height, width) for class-specific or class-agnostic prediction.
uncertainty_func (func): uncertainty calculation function.
cfg (dict): Testing config of point head.
Returns:
point_indices (Tensor): A tensor of shape (batch_size, num_points)
that contains indices from [0, height x width) of the most
uncertain points.
point_coords (Tensor): A tensor of shape (batch_size, num_points,
2) that contains [0, 1] x [0, 1] normalized coordinates of the
most uncertain points from the ``height x width`` grid .
"""
num_points = cfg.subdivision_num_points
uncertainty_map = uncertainty_func(seg_logits)
batch_size, _, height, width = uncertainty_map.shape
h_step = 1.0 / height
w_step = 1.0 / width
uncertainty_map = uncertainty_map.view(batch_size, height * width)
num_points = min(height * width, num_points)
point_indices = uncertainty_map.topk(num_points, dim=1)[1]
point_coords = torch.zeros(
batch_size,
num_points,
2,
dtype=torch.float,
device=seg_logits.device)
point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
width).float() * w_step
point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
width).float() * h_step
return point_indices, point_coords
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.uniformer.mmcv.cnn import ConvModule
from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
try:
from annotator.uniformer.mmcv.ops import PSAMask
except ModuleNotFoundError:
PSAMask = None
@HEADS.register_module()
class PSAHead(BaseDecodeHead):
"""Point-wise Spatial Attention Network for Scene Parsing.
This head is the implementation of `PSANet
<https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.
Args:
mask_size (tuple[int]): The PSA mask size. It usually equals input
size.
psa_type (str): The type of psa module. Options are 'collect',
'distribute', 'bi-direction'. Default: 'bi-direction'
compact (bool): Whether use compact map for 'collect' mode.
Default: True.
shrink_factor (int): The downsample factors of psa mask. Default: 2.
normalization_factor (float): The normalize factor of attention.
psa_softmax (bool): Whether use softmax for attention.
"""
def __init__(self,
mask_size,
psa_type='bi-direction',
compact=False,
shrink_factor=2,
normalization_factor=1.0,
psa_softmax=True,
**kwargs):
if PSAMask is None:
raise RuntimeError('Please install mmcv-full for PSAMask ops')
super(PSAHead, self).__init__(**kwargs)
assert psa_type in ['collect', 'distribute', 'bi-direction']
self.psa_type = psa_type
self.compact = compact
self.shrink_factor = shrink_factor
self.mask_size = mask_size
mask_h, mask_w = mask_size
self.psa_softmax = psa_softmax
if normalization_factor is None:
normalization_factor = mask_h * mask_w
self.normalization_factor = normalization_factor
self.reduce = ConvModule(
self.in_channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.attention = nn.Sequential(
ConvModule(
self.channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(
self.channels, mask_h * mask_w, kernel_size=1, bias=False))
if psa_type == 'bi-direction':
self.reduce_p = ConvModule(
self.in_channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.attention_p = nn.Sequential(
ConvModule(
self.channels,
self.channels,
kernel_size=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
nn.Conv2d(
self.channels, mask_h * mask_w, kernel_size=1, bias=False))
self.psamask_collect = PSAMask('collect', mask_size)
self.psamask_distribute = PSAMask('distribute', mask_size)
else:
self.psamask = PSAMask(psa_type, mask_size)
self.proj = ConvModule(
self.channels * (2 if psa_type == 'bi-direction' else 1),
self.in_channels,
kernel_size=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.bottleneck = ConvModule(
self.in_channels * 2,
self.channels,
kernel_size=3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
identity = x
align_corners = self.align_corners
if self.psa_type in ['collect', 'distribute']:
out = self.reduce(x)
n, c, h, w = out.size()
if self.shrink_factor != 1:
if h % self.shrink_factor and w % self.shrink_factor:
h = (h - 1) // self.shrink_factor + 1
w = (w - 1) // self.shrink_factor + 1
align_corners = True
else:
h = h // self.shrink_factor
w = w // self.shrink_factor
align_corners = False
out = resize(
out,
size=(h, w),
mode='bilinear',
align_corners=align_corners)
y = self.attention(out)
if self.compact:
if self.psa_type == 'collect':
y = y.view(n, h * w,
h * w).transpose(1, 2).view(n, h * w, h, w)
else:
y = self.psamask(y)
if self.psa_softmax:
y = F.softmax(y, dim=1)
out = torch.bmm(
out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
n, c, h, w) * (1.0 / self.normalization_factor)
else:
x_col = self.reduce(x)
x_dis = self.reduce_p(x)
n, c, h, w = x_col.size()
if self.shrink_factor != 1:
if h % self.shrink_factor and w % self.shrink_factor:
h = (h - 1) // self.shrink_factor + 1
w = (w - 1) // self.shrink_factor + 1
align_corners = True
else:
h = h // self.shrink_factor
w = w // self.shrink_factor
align_corners = False
x_col = resize(
x_col,
size=(h, w),
mode='bilinear',
align_corners=align_corners)
x_dis = resize(
x_dis,
size=(h, w),
mode='bilinear',
align_corners=align_corners)
y_col = self.attention(x_col)
y_dis = self.attention_p(x_dis)
if self.compact:
y_dis = y_dis.view(n, h * w,
h * w).transpose(1, 2).view(n, h * w, h, w)
else:
y_col = self.psamask_collect(y_col)
y_dis = self.psamask_distribute(y_dis)
if self.psa_softmax:
y_col = F.softmax(y_col, dim=1)
y_dis = F.softmax(y_dis, dim=1)
x_col = torch.bmm(
x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
n, c, h, w) * (1.0 / self.normalization_factor)
x_dis = torch.bmm(
x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
n, c, h, w) * (1.0 / self.normalization_factor)
out = torch.cat([x_col, x_dis], 1)
out = self.proj(out)
out = resize(
out,
size=identity.shape[2:],
mode='bilinear',
align_corners=align_corners)
out = self.bottleneck(torch.cat((identity, out), dim=1))
out = self.cls_seg(out)
return out
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule
from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
class PPM(nn.ModuleList):
"""Pooling Pyramid Module used in PSPNet.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
conv_cfg (dict|None): Config of conv layers.
norm_cfg (dict|None): Config of norm layers.
act_cfg (dict): Config of activation layers.
align_corners (bool): align_corners argument of F.interpolate.
"""
def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
act_cfg, align_corners):
super(PPM, self).__init__()
self.pool_scales = pool_scales
self.align_corners = align_corners
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
for pool_scale in pool_scales:
self.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
ConvModule(
self.in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)))
def forward(self, x):
"""Forward function."""
ppm_outs = []
for ppm in self:
ppm_out = ppm(x)
upsampled_ppm_out = resize(
ppm_out,
size=x.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
ppm_outs.append(upsampled_ppm_out)
return ppm_outs
@HEADS.register_module()
class PSPHead(BaseDecodeHead):
"""Pyramid Scene Parsing Network.
This head is the implementation of
`PSPNet <https://arxiv.org/abs/1612.01105>`_.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module. Default: (1, 2, 3, 6).
"""
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super(PSPHead, self).__init__(**kwargs)
assert isinstance(pool_scales, (list, tuple))
self.pool_scales = pool_scales
self.psp_modules = PPM(
self.pool_scales,
self.in_channels,
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.bottleneck = ConvModule(
self.in_channels + len(pool_scales) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
output = self.cls_seg(output)
return output
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from .aspp_head import ASPPHead, ASPPModule
class DepthwiseSeparableASPPModule(ASPPModule):
"""Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
conv."""
def __init__(self, **kwargs):
super(DepthwiseSeparableASPPModule, self).__init__(**kwargs)
for i, dilation in enumerate(self.dilations):
if dilation > 1:
self[i] = DepthwiseSeparableConvModule(
self.in_channels,
self.channels,
3,
dilation=dilation,
padding=dilation,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
@HEADS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
"""Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation.
This head is the implementation of `DeepLabV3+
<https://arxiv.org/abs/1802.02611>`_.
Args:
c1_in_channels (int): The input channels of c1 decoder. If is 0,
the no decoder will be used.
c1_channels (int): The intermediate channels of c1 decoder.
"""
def __init__(self, c1_in_channels, c1_channels, **kwargs):
super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
assert c1_in_channels >= 0
self.aspp_modules = DepthwiseSeparableASPPModule(
dilations=self.dilations,
in_channels=self.in_channels,
channels=self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if c1_in_channels > 0:
self.c1_bottleneck = ConvModule(
c1_in_channels,
c1_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
else:
self.c1_bottleneck = None
self.sep_bottleneck = nn.Sequential(
DepthwiseSeparableConvModule(
self.channels + c1_channels,
self.channels,
3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
DepthwiseSeparableConvModule(
self.channels,
self.channels,
3,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
aspp_outs = [
resize(
self.image_pool(x),
size=x.size()[2:],
mode='bilinear',
align_corners=self.align_corners)
]
aspp_outs.extend(self.aspp_modules(x))
aspp_outs = torch.cat(aspp_outs, dim=1)
output = self.bottleneck(aspp_outs)
if self.c1_bottleneck is not None:
c1_output = self.c1_bottleneck(inputs[0])
output = resize(
input=output,
size=c1_output.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
output = torch.cat([output, c1_output], dim=1)
output = self.sep_bottleneck(output)
output = self.cls_seg(output)
return output
from annotator.uniformer.mmcv.cnn import DepthwiseSeparableConvModule
from ..builder import HEADS
from .fcn_head import FCNHead
@HEADS.register_module()
class DepthwiseSeparableFCNHead(FCNHead):
"""Depthwise-Separable Fully Convolutional Network for Semantic
Segmentation.
This head is implemented according to Fast-SCNN paper.
Args:
in_channels(int): Number of output channels of FFM.
channels(int): Number of middle-stage channels in the decode head.
concat_input(bool): Whether to concatenate original decode input into
the result of several consecutive convolution layers.
Default: True.
num_classes(int): Used to determine the dimension of
final prediction tensor.
in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
norm_cfg (dict | None): Config of norm layers.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
loss_decode(dict): Config of loss type and some
relevant additional options.
"""
def __init__(self, **kwargs):
super(DepthwiseSeparableFCNHead, self).__init__(**kwargs)
self.convs[0] = DepthwiseSeparableConvModule(
self.in_channels,
self.channels,
kernel_size=self.kernel_size,
padding=self.kernel_size // 2,
norm_cfg=self.norm_cfg)
for i in range(1, self.num_convs):
self.convs[i] = DepthwiseSeparableConvModule(
self.channels,
self.channels,
kernel_size=self.kernel_size,
padding=self.kernel_size // 2,
norm_cfg=self.norm_cfg)
if self.concat_input:
self.conv_cat = DepthwiseSeparableConvModule(
self.in_channels + self.channels,
self.channels,
kernel_size=self.kernel_size,
padding=self.kernel_size // 2,
norm_cfg=self.norm_cfg)
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.cnn import ConvModule
from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
from .psp_head import PPM
@HEADS.register_module()
class UPerHead(BaseDecodeHead):
"""Unified Perceptual Parsing for Scene Understanding.
This head is the implementation of `UPerNet
<https://arxiv.org/abs/1807.10221>`_.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module applied on the last feature. Default: (1, 2, 3, 6).
"""
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
super(UPerHead, self).__init__(
input_transform='multiple_select', **kwargs)
# PSP Module
self.psp_modules = PPM(
pool_scales,
self.in_channels[-1],
self.channels,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
align_corners=self.align_corners)
self.bottleneck = ConvModule(
self.in_channels[-1] + len(pool_scales) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
# FPN Module
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
for in_channels in self.in_channels[:-1]: # skip the top layer
l_conv = ConvModule(
in_channels,
self.channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
fpn_conv = ConvModule(
self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
inplace=False)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
self.fpn_bottleneck = ConvModule(
len(self.in_channels) * self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
def psp_forward(self, inputs):
"""Forward function of PSP module."""
x = inputs[-1]
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
return output
def forward(self, inputs):
"""Forward function."""
inputs = self._transform_inputs(inputs)
# build laterals
laterals = [
lateral_conv(inputs[i])
for i, lateral_conv in enumerate(self.lateral_convs)
]
laterals.append(self.psp_forward(inputs))
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += resize(
laterals[i],
size=prev_shape,
mode='bilinear',
align_corners=self.align_corners)
# build outputs
fpn_outs = [
self.fpn_convs[i](laterals[i])
for i in range(used_backbone_levels - 1)
]
# append psp feature
fpn_outs.append(laterals[-1])
for i in range(used_backbone_levels - 1, 0, -1):
fpn_outs[i] = resize(
fpn_outs[i],
size=fpn_outs[0].shape[2:],
mode='bilinear',
align_corners=self.align_corners)
fpn_outs = torch.cat(fpn_outs, dim=1)
output = self.fpn_bottleneck(fpn_outs)
output = self.cls_seg(output)
return output
from .accuracy import Accuracy, accuracy
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .lovasz_loss import LovaszLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
__all__ = [
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment