Commit 57f6da5c authored by bailuo's avatar bailuo
Browse files

readme

parents
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import distance2bbox, force_fp32, multi_apply, multiclass_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import ConvModule, Scale, bias_init_with_prob
INF = 1e8
@HEADS.register_module
class FCOSHead(nn.Module):
"""
Fully Convolutional One-Stage Object Detection head from [1]_.
The FCOS head does not use anchor boxes. Instead bounding boxes are
predicted at each pixel and a centerness measure is used to supress
low-quality predictions.
References:
.. [1] https://arxiv.org/abs/1904.01355
Example:
>>> self = FCOSHead(11, 7)
>>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
>>> cls_score, bbox_pred, centerness = self.forward(feats)
>>> assert len(cls_score) == len(self.scales)
"""
def __init__(self,
num_classes,
in_channels,
feat_channels=256,
stacked_convs=4,
strides=(4, 8, 16, 32, 64),
regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512),
(512, INF)),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
conv_cfg=None,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)):
super(FCOSHead, self).__init__()
self.num_classes = num_classes
self.cls_out_channels = num_classes - 1
self.in_channels = in_channels
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.regress_ranges = regress_ranges
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.loss_centerness = build_loss(loss_centerness)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.fp16_enabled = False
self._init_layers()
def _init_layers(self):
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
self.fcos_cls = nn.Conv2d(
self.feat_channels, self.cls_out_channels, 3, padding=1)
self.fcos_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
self.fcos_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
def init_weights(self):
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.fcos_cls, std=0.01, bias=bias_cls)
normal_init(self.fcos_reg, std=0.01)
normal_init(self.fcos_centerness, std=0.01)
def forward(self, feats):
return multi_apply(self.forward_single, feats, self.scales)
def forward_single(self, x, scale):
cls_feat = x
reg_feat = x
for cls_layer in self.cls_convs:
cls_feat = cls_layer(cls_feat)
cls_score = self.fcos_cls(cls_feat)
centerness = self.fcos_centerness(cls_feat)
for reg_layer in self.reg_convs:
reg_feat = reg_layer(reg_feat)
# scale the bbox_pred of different level
# float to avoid overflow when enabling FP16
bbox_pred = scale(self.fcos_reg(reg_feat)).float().exp()
return cls_score, bbox_pred, centerness
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
def loss(self,
cls_scores,
bbox_preds,
centernesses,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
assert len(cls_scores) == len(bbox_preds) == len(centernesses)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
gt_labels)
num_imgs = cls_scores[0].size(0)
# flatten cls_scores, bbox_preds and centerness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
for bbox_pred in bbox_preds
]
flatten_centerness = [
centerness.permute(0, 2, 3, 1).reshape(-1)
for centerness in centernesses
]
flatten_cls_scores = torch.cat(flatten_cls_scores)
flatten_bbox_preds = torch.cat(flatten_bbox_preds)
flatten_centerness = torch.cat(flatten_centerness)
flatten_labels = torch.cat(labels)
flatten_bbox_targets = torch.cat(bbox_targets)
# repeat points to align with bbox_preds
flatten_points = torch.cat(
[points.repeat(num_imgs, 1) for points in all_level_points])
pos_inds = flatten_labels.nonzero().reshape(-1)
num_pos = len(pos_inds)
loss_cls = self.loss_cls(
flatten_cls_scores, flatten_labels,
avg_factor=num_pos + num_imgs) # avoid num_pos is 0
pos_bbox_preds = flatten_bbox_preds[pos_inds]
pos_centerness = flatten_centerness[pos_inds]
if num_pos > 0:
pos_bbox_targets = flatten_bbox_targets[pos_inds]
pos_centerness_targets = self.centerness_target(pos_bbox_targets)
pos_points = flatten_points[pos_inds]
pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
pos_decoded_target_preds = distance2bbox(pos_points,
pos_bbox_targets)
# centerness weighted iou loss
loss_bbox = self.loss_bbox(
pos_decoded_bbox_preds,
pos_decoded_target_preds,
weight=pos_centerness_targets,
avg_factor=pos_centerness_targets.sum())
loss_centerness = self.loss_centerness(pos_centerness,
pos_centerness_targets)
else:
loss_bbox = pos_bbox_preds.sum()
loss_centerness = pos_centerness.sum()
return dict(
loss_cls=loss_cls,
loss_bbox=loss_bbox,
loss_centerness=loss_centerness)
@force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
def get_bboxes(self,
cls_scores,
bbox_preds,
centernesses,
img_metas,
cfg,
rescale=None):
assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds[i][img_id].detach() for i in range(num_levels)
]
centerness_pred_list = [
centernesses[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
det_bboxes = self.get_bboxes_single(cls_score_list, bbox_pred_list,
centerness_pred_list,
mlvl_points, img_shape,
scale_factor, cfg, rescale)
result_list.append(det_bboxes)
return result_list
def get_bboxes_single(self,
cls_scores,
bbox_preds,
centernesses,
mlvl_points,
img_shape,
scale_factor,
cfg,
rescale=False):
assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
mlvl_bboxes = []
mlvl_scores = []
mlvl_centerness = []
for cls_score, bbox_pred, centerness, points in zip(
cls_scores, bbox_preds, centernesses, mlvl_points):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
scores = cls_score.permute(1, 2, 0).reshape(
-1, self.cls_out_channels).sigmoid()
centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
max_scores, _ = (scores * centerness[:, None]).max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
points = points[topk_inds, :]
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
centerness = centerness[topk_inds]
bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_centerness.append(centerness)
mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
mlvl_centerness = torch.cat(mlvl_centerness)
det_bboxes, det_labels = multiclass_nms(
mlvl_bboxes,
mlvl_scores,
cfg.score_thr,
cfg.nms,
cfg.max_per_img,
score_factors=mlvl_centerness)
return det_bboxes, det_labels
def get_points(self, featmap_sizes, dtype, device):
"""Get points according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
dtype (torch.dtype): Type of points.
device (torch.device): Device of points.
Returns:
tuple: points of each image.
"""
mlvl_points = []
for i in range(len(featmap_sizes)):
mlvl_points.append(
self.get_points_single(featmap_sizes[i], self.strides[i],
dtype, device))
return mlvl_points
def get_points_single(self, featmap_size, stride, dtype, device):
h, w = featmap_size
x_range = torch.arange(
0, w * stride, stride, dtype=dtype, device=device)
y_range = torch.arange(
0, h * stride, stride, dtype=dtype, device=device)
y, x = torch.meshgrid(y_range, x_range)
points = torch.stack(
(x.reshape(-1), y.reshape(-1)), dim=-1) + stride // 2
return points
def fcos_target(self, points, gt_bboxes_list, gt_labels_list):
assert len(points) == len(self.regress_ranges)
num_levels = len(points)
# expand regress ranges to align with points
expanded_regress_ranges = [
points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
points[i]) for i in range(num_levels)
]
# concat all levels points and regress ranges
concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
concat_points = torch.cat(points, dim=0)
# get labels and bbox_targets of each image
labels_list, bbox_targets_list = multi_apply(
self.fcos_target_single,
gt_bboxes_list,
gt_labels_list,
points=concat_points,
regress_ranges=concat_regress_ranges)
# split to per img, per level
num_points = [center.size(0) for center in points]
labels_list = [labels.split(num_points, 0) for labels in labels_list]
bbox_targets_list = [
bbox_targets.split(num_points, 0)
for bbox_targets in bbox_targets_list
]
# concat per level image
concat_lvl_labels = []
concat_lvl_bbox_targets = []
for i in range(num_levels):
concat_lvl_labels.append(
torch.cat([labels[i] for labels in labels_list]))
concat_lvl_bbox_targets.append(
torch.cat(
[bbox_targets[i] for bbox_targets in bbox_targets_list]))
return concat_lvl_labels, concat_lvl_bbox_targets
def fcos_target_single(self, gt_bboxes, gt_labels, points, regress_ranges):
num_points = points.size(0)
num_gts = gt_labels.size(0)
if num_gts == 0:
return gt_labels.new_zeros(num_points), \
gt_bboxes.new_zeros((num_points, 4))
areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * (
gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1)
# TODO: figure out why these two are different
# areas = areas[None].expand(num_points, num_gts)
areas = areas[None].repeat(num_points, 1)
regress_ranges = regress_ranges[:, None, :].expand(
num_points, num_gts, 2)
gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
xs, ys = points[:, 0], points[:, 1]
xs = xs[:, None].expand(num_points, num_gts)
ys = ys[:, None].expand(num_points, num_gts)
left = xs - gt_bboxes[..., 0]
right = gt_bboxes[..., 2] - xs
top = ys - gt_bboxes[..., 1]
bottom = gt_bboxes[..., 3] - ys
bbox_targets = torch.stack((left, top, right, bottom), -1)
# condition1: inside a gt bbox
inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
# condition2: limit the regression range for each location
max_regress_distance = bbox_targets.max(-1)[0]
inside_regress_range = (
max_regress_distance >= regress_ranges[..., 0]) & (
max_regress_distance <= regress_ranges[..., 1])
# if there are still more than one objects for a location,
# we choose the one with minimal area
areas[inside_gt_bbox_mask == 0] = INF
areas[inside_regress_range == 0] = INF
min_area, min_area_inds = areas.min(dim=1)
labels = gt_labels[min_area_inds]
labels[min_area == INF] = 0
bbox_targets = bbox_targets[range(num_points), min_area_inds]
return labels, bbox_targets
def centerness_target(self, pos_bbox_targets):
# only calculate pos centerness targets, otherwise there may be nan
left_right = pos_bbox_targets[:, [0, 2]]
top_bottom = pos_bbox_targets[:, [1, 3]]
centerness_targets = (
left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
return torch.sqrt(centerness_targets)
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import multi_apply, multiclass_nms
from mmdet.ops import DeformConv
from ..builder import build_loss
from ..registry import HEADS
from ..utils import ConvModule, bias_init_with_prob
INF = 1e8
class FeatureAlign(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
deformable_groups=4):
super(FeatureAlign, self).__init__()
offset_channels = kernel_size * kernel_size * 2
self.conv_offset = nn.Conv2d(
4, deformable_groups * offset_channels, 1, bias=False)
self.conv_adaption = DeformConv(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=(kernel_size - 1) // 2,
deformable_groups=deformable_groups)
self.relu = nn.ReLU(inplace=True)
def init_weights(self):
normal_init(self.conv_offset, std=0.1)
normal_init(self.conv_adaption, std=0.01)
def forward(self, x, shape):
offset = self.conv_offset(shape)
x = self.relu(self.conv_adaption(x, offset))
return x
@HEADS.register_module
class FoveaHead(nn.Module):
"""FoveaBox: Beyond Anchor-based Object Detector
https://arxiv.org/abs/1904.03797
"""
def __init__(self,
num_classes,
in_channels,
feat_channels=256,
stacked_convs=4,
strides=(4, 8, 16, 32, 64),
base_edge_list=(16, 32, 64, 128, 256),
scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128,
512)),
sigma=0.4,
with_deform=False,
deformable_groups=4,
loss_cls=None,
loss_bbox=None,
conv_cfg=None,
norm_cfg=None):
super(FoveaHead, self).__init__()
self.num_classes = num_classes
self.cls_out_channels = num_classes - 1
self.in_channels = in_channels
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.base_edge_list = base_edge_list
self.scale_ranges = scale_ranges
self.sigma = sigma
self.with_deform = with_deform
self.deformable_groups = deformable_groups
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self._init_layers()
def _init_layers(self):
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
# box branch
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
self.fovea_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
# cls branch
if not self.with_deform:
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
self.fovea_cls = nn.Conv2d(
self.feat_channels, self.cls_out_channels, 3, padding=1)
else:
self.cls_convs.append(
ConvModule(
self.feat_channels, (self.feat_channels * 4),
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
self.cls_convs.append(
ConvModule((self.feat_channels * 4), (self.feat_channels * 4),
1,
stride=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
bias=self.norm_cfg is None))
self.feature_adaption = FeatureAlign(
self.feat_channels,
self.feat_channels,
kernel_size=3,
deformable_groups=self.deformable_groups)
self.fovea_cls = nn.Conv2d(
int(self.feat_channels * 4),
self.cls_out_channels,
3,
padding=1)
def init_weights(self):
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.fovea_cls, std=0.01, bias=bias_cls)
normal_init(self.fovea_reg, std=0.01)
if self.with_deform:
self.feature_adaption.init_weights()
def forward(self, feats):
return multi_apply(self.forward_single, feats)
def forward_single(self, x):
cls_feat = x
reg_feat = x
for reg_layer in self.reg_convs:
reg_feat = reg_layer(reg_feat)
bbox_pred = self.fovea_reg(reg_feat)
if self.with_deform:
cls_feat = self.feature_adaption(cls_feat, bbox_pred.exp())
for cls_layer in self.cls_convs:
cls_feat = cls_layer(cls_feat)
cls_score = self.fovea_cls(cls_feat)
return cls_score, bbox_pred
def get_points(self, featmap_sizes, dtype, device, flatten=False):
points = []
for featmap_size in featmap_sizes:
x_range = torch.arange(
featmap_size[1], dtype=dtype, device=device) + 0.5
y_range = torch.arange(
featmap_size[0], dtype=dtype, device=device) + 0.5
y, x = torch.meshgrid(y_range, x_range)
if flatten:
points.append((y.flatten(), x.flatten()))
else:
points.append((y, x))
return points
def loss(self,
cls_scores,
bbox_preds,
gt_bbox_list,
gt_label_list,
img_metas,
cfg,
gt_bboxes_ignore=None):
assert len(cls_scores) == len(bbox_preds)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
num_imgs = cls_scores[0].size(0)
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
for bbox_pred in bbox_preds
]
flatten_cls_scores = torch.cat(flatten_cls_scores)
flatten_bbox_preds = torch.cat(flatten_bbox_preds)
flatten_labels, flatten_bbox_targets = self.fovea_target(
gt_bbox_list, gt_label_list, featmap_sizes, points)
pos_inds = (flatten_labels > 0).nonzero().view(-1)
num_pos = len(pos_inds)
loss_cls = self.loss_cls(
flatten_cls_scores, flatten_labels, avg_factor=num_pos + num_imgs)
if num_pos > 0:
pos_bbox_preds = flatten_bbox_preds[pos_inds]
pos_bbox_targets = flatten_bbox_targets[pos_inds]
pos_weights = pos_bbox_targets.new_zeros(
pos_bbox_targets.size()) + 1.0
loss_bbox = self.loss_bbox(
pos_bbox_preds,
pos_bbox_targets,
pos_weights,
avg_factor=num_pos)
else:
loss_bbox = torch.tensor([0],
dtype=flatten_bbox_preds.dtype,
device=flatten_bbox_preds.device)
return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
def fovea_target(self, gt_bbox_list, gt_label_list, featmap_sizes, points):
label_list, bbox_target_list = multi_apply(
self.fovea_target_single,
gt_bbox_list,
gt_label_list,
featmap_size_list=featmap_sizes,
point_list=points)
flatten_labels = [
torch.cat([
labels_level_img.flatten() for labels_level_img in labels_level
]) for labels_level in zip(*label_list)
]
flatten_bbox_targets = [
torch.cat([
bbox_targets_level_img.reshape(-1, 4)
for bbox_targets_level_img in bbox_targets_level
]) for bbox_targets_level in zip(*bbox_target_list)
]
flatten_labels = torch.cat(flatten_labels)
flatten_bbox_targets = torch.cat(flatten_bbox_targets)
return flatten_labels, flatten_bbox_targets
def fovea_target_single(self,
gt_bboxes_raw,
gt_labels_raw,
featmap_size_list=None,
point_list=None):
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
(gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
label_list = []
bbox_target_list = []
# for each pyramid, find the cls and box target
for base_len, (lower_bound, upper_bound), stride, featmap_size, \
(y, x) in zip(self.base_edge_list, self.scale_ranges,
self.strides, featmap_size_list, point_list):
labels = gt_labels_raw.new_zeros(featmap_size)
bbox_targets = gt_bboxes_raw.new(featmap_size[0], featmap_size[1],
4) + 1
# scale assignment
hit_indices = ((gt_areas >= lower_bound) &
(gt_areas <= upper_bound)).nonzero().flatten()
if len(hit_indices) == 0:
label_list.append(labels)
bbox_target_list.append(torch.log(bbox_targets))
continue
_, hit_index_order = torch.sort(-gt_areas[hit_indices])
hit_indices = hit_indices[hit_index_order]
gt_bboxes = gt_bboxes_raw[hit_indices, :] / stride
gt_labels = gt_labels_raw[hit_indices]
half_w = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0])
half_h = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1])
# valid fovea area: left, right, top, down
pos_left = torch.ceil(
gt_bboxes[:, 0] + (1 - self.sigma) * half_w - 0.5).long().\
clamp(0, featmap_size[1] - 1)
pos_right = torch.floor(
gt_bboxes[:, 0] + (1 + self.sigma) * half_w - 0.5).long().\
clamp(0, featmap_size[1] - 1)
pos_top = torch.ceil(
gt_bboxes[:, 1] + (1 - self.sigma) * half_h - 0.5).long().\
clamp(0, featmap_size[0] - 1)
pos_down = torch.floor(
gt_bboxes[:, 1] + (1 + self.sigma) * half_h - 0.5).long().\
clamp(0, featmap_size[0] - 1)
for px1, py1, px2, py2, label, (gt_x1, gt_y1, gt_x2, gt_y2) in \
zip(pos_left, pos_top, pos_right, pos_down, gt_labels,
gt_bboxes_raw[hit_indices, :]):
labels[py1:py2 + 1, px1:px2 + 1] = label
bbox_targets[py1:py2 + 1, px1:px2 + 1, 0] = \
(stride * x[py1:py2 + 1, px1:px2 + 1] - gt_x1) / base_len
bbox_targets[py1:py2 + 1, px1:px2 + 1, 1] = \
(stride * y[py1:py2 + 1, px1:px2 + 1] - gt_y1) / base_len
bbox_targets[py1:py2 + 1, px1:px2 + 1, 2] = \
(gt_x2 - stride * x[py1:py2 + 1, px1:px2 + 1]) / base_len
bbox_targets[py1:py2 + 1, px1:px2 + 1, 3] = \
(gt_y2 - stride * y[py1:py2 + 1, px1:px2 + 1]) / base_len
bbox_targets = bbox_targets.clamp(min=1. / 16, max=16.)
label_list.append(labels)
bbox_target_list.append(torch.log(bbox_targets))
return label_list, bbox_target_list
def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg, rescale=None):
assert len(cls_scores) == len(bbox_preds)
num_levels = len(cls_scores)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
points = self.get_points(
featmap_sizes,
bbox_preds[0].dtype,
bbox_preds[0].device,
flatten=True)
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
det_bboxes = self.get_bboxes_single(cls_score_list, bbox_pred_list,
featmap_sizes, points,
img_shape, scale_factor, cfg,
rescale)
result_list.append(det_bboxes)
return result_list
def get_bboxes_single(self,
cls_scores,
bbox_preds,
featmap_sizes,
point_list,
img_shape,
scale_factor,
cfg,
rescale=False):
assert len(cls_scores) == len(bbox_preds) == len(point_list)
det_bboxes = []
det_scores = []
for cls_score, bbox_pred, featmap_size, stride, base_len, (y, x) \
in zip(cls_scores, bbox_preds, featmap_sizes, self.strides,
self.base_edge_list, point_list):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
scores = cls_score.permute(1, 2, 0).reshape(
-1, self.cls_out_channels).sigmoid()
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4).exp()
nms_pre = cfg.get('nms_pre', -1)
if (nms_pre > 0) and (scores.shape[0] > nms_pre):
max_scores, _ = scores.max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
y = y[topk_inds]
x = x[topk_inds]
x1 = (stride * x - base_len * bbox_pred[:, 0]).\
clamp(min=0, max=img_shape[1] - 1)
y1 = (stride * y - base_len * bbox_pred[:, 1]).\
clamp(min=0, max=img_shape[0] - 1)
x2 = (stride * x + base_len * bbox_pred[:, 2]).\
clamp(min=0, max=img_shape[1] - 1)
y2 = (stride * y + base_len * bbox_pred[:, 3]).\
clamp(min=0, max=img_shape[0] - 1)
bboxes = torch.stack([x1, y1, x2, y2], -1)
det_bboxes.append(bboxes)
det_scores.append(scores)
det_bboxes = torch.cat(det_bboxes)
if rescale:
det_bboxes /= det_bboxes.new_tensor(scale_factor)
det_scores = torch.cat(det_scores)
padding = det_scores.new_zeros(det_scores.shape[0], 1)
det_scores = torch.cat([padding, det_scores], dim=1)
det_bboxes, det_labels = multiclass_nms(det_bboxes, det_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
import torch
import torch.nn.functional as F
from mmdet.core import bbox2delta, bbox_overlaps, delta2bbox
from ..registry import HEADS
from .retina_head import RetinaHead
@HEADS.register_module
class FreeAnchorRetinaHead(RetinaHead):
def __init__(self,
num_classes,
in_channels,
stacked_convs=4,
octave_base_scale=4,
scales_per_octave=3,
conv_cfg=None,
norm_cfg=None,
pre_anchor_topk=50,
bbox_thr=0.6,
gamma=2.0,
alpha=0.5,
**kwargs):
super(FreeAnchorRetinaHead,
self).__init__(num_classes, in_channels, stacked_convs,
octave_base_scale, scales_per_octave, conv_cfg,
norm_cfg, **kwargs)
self.pre_anchor_topk = pre_anchor_topk
self.bbox_thr = bbox_thr
self.gamma = gamma
self.alpha = alpha
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
anchor_list, _ = self.get_anchors(featmap_sizes, img_metas)
anchors = [torch.cat(anchor) for anchor in anchor_list]
# concatenate each level
cls_scores = [
cls.permute(0, 2, 3,
1).reshape(cls.size(0), -1, self.cls_out_channels)
for cls in cls_scores
]
bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4)
for bbox_pred in bbox_preds
]
cls_scores = torch.cat(cls_scores, dim=1)
bbox_preds = torch.cat(bbox_preds, dim=1)
cls_prob = torch.sigmoid(cls_scores)
box_prob = []
num_pos = 0
positive_losses = []
for _, (anchors_, gt_labels_, gt_bboxes_, cls_prob_,
bbox_preds_) in enumerate(
zip(anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds)):
gt_labels_ -= 1
with torch.no_grad():
# box_localization: a_{j}^{loc}, shape: [j, 4]
pred_boxes = delta2bbox(anchors_, bbox_preds_,
self.target_means, self.target_stds)
# object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
object_box_iou = bbox_overlaps(gt_bboxes_, pred_boxes)
# object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
t1 = self.bbox_thr
t2 = object_box_iou.max(
dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)
object_box_prob = ((object_box_iou - t1) / (t2 - t1)).clamp(
min=0, max=1)
# object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
num_obj = gt_labels_.size(0)
indices = torch.stack(
[torch.arange(num_obj).type_as(gt_labels_), gt_labels_],
dim=0)
object_cls_box_prob = torch.sparse_coo_tensor(
indices, object_box_prob)
# image_box_iou: P{a_{j} \in A_{+}}, shape: [c, j]
"""
from "start" to "end" implement:
image_box_iou = torch.sparse.max(object_cls_box_prob,
dim=0).t()
"""
# start
box_cls_prob = torch.sparse.sum(
object_cls_box_prob, dim=0).to_dense()
indices = torch.nonzero(box_cls_prob).t_()
if indices.numel() == 0:
image_box_prob = torch.zeros(
anchors_.size(0),
self.cls_out_channels).type_as(object_box_prob)
else:
nonzero_box_prob = torch.where(
(gt_labels_.unsqueeze(dim=-1) == indices[0]),
object_box_prob[:, indices[1]],
torch.tensor(
[0]).type_as(object_box_prob)).max(dim=0).values
# upmap to shape [j, c]
image_box_prob = torch.sparse_coo_tensor(
indices.flip([0]),
nonzero_box_prob,
size=(anchors_.size(0),
self.cls_out_channels)).to_dense()
# end
box_prob.append(image_box_prob)
# construct bags for objects
match_quality_matrix = bbox_overlaps(gt_bboxes_, anchors_)
_, matched = torch.topk(
match_quality_matrix,
self.pre_anchor_topk,
dim=1,
sorted=False)
del match_quality_matrix
# matched_cls_prob: P_{ij}^{cls}
matched_cls_prob = torch.gather(
cls_prob_[matched], 2,
gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
1)).squeeze(2)
# matched_box_prob: P_{ij}^{loc}
matched_anchors = anchors_[matched]
matched_object_targets = bbox2delta(
matched_anchors,
gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors),
self.target_means, self.target_stds)
loss_bbox = self.loss_bbox(
bbox_preds_[matched],
matched_object_targets,
reduction_override='none').sum(-1)
matched_box_prob = torch.exp(-loss_bbox)
# positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
num_pos += len(gt_bboxes_)
positive_losses.append(
self.positive_bag_loss(matched_cls_prob, matched_box_prob))
positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)
# box_prob: P{a_{j} \in A_{+}}
box_prob = torch.stack(box_prob, dim=0)
# negative_loss:
# \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B||
negative_loss = self.negative_bag_loss(cls_prob, box_prob).sum() / max(
1, num_pos * self.pre_anchor_topk)
losses = {
'positive_bag_loss': positive_loss,
'negative_bag_loss': negative_loss
}
return losses
def positive_bag_loss(self, matched_cls_prob, matched_box_prob):
# bag_prob = Mean-max(matched_prob)
matched_prob = matched_cls_prob * matched_box_prob
weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
weight /= weight.sum(dim=1).unsqueeze(dim=-1)
bag_prob = (weight * matched_prob).sum(dim=1)
# positive_bag_loss = -self.alpha * log(bag_prob)
return self.alpha * F.binary_cross_entropy(
bag_prob, torch.ones_like(bag_prob), reduction='none')
def negative_bag_loss(self, cls_prob, box_prob):
prob = cls_prob * (1 - box_prob)
negative_bag_loss = prob**self.gamma * F.binary_cross_entropy(
prob, torch.zeros_like(prob), reduction='none')
return (1 - self.alpha) * negative_bag_loss
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.ops import MaskedConv2d
from ..registry import HEADS
from ..utils import ConvModule, bias_init_with_prob
from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
@HEADS.register_module
class GARetinaHead(GuidedAnchorHead):
"""Guided-Anchor-based RetinaNet head."""
def __init__(self,
num_classes,
in_channels,
stacked_convs=4,
conv_cfg=None,
norm_cfg=None,
**kwargs):
self.stacked_convs = stacked_convs
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
super(GARetinaHead, self).__init__(num_classes, in_channels, **kwargs)
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.conv_loc = nn.Conv2d(self.feat_channels, 1, 1)
self.conv_shape = nn.Conv2d(self.feat_channels, self.num_anchors * 2,
1)
self.feature_adaption_cls = FeatureAdaption(
self.feat_channels,
self.feat_channels,
kernel_size=3,
deformable_groups=self.deformable_groups)
self.feature_adaption_reg = FeatureAdaption(
self.feat_channels,
self.feat_channels,
kernel_size=3,
deformable_groups=self.deformable_groups)
self.retina_cls = MaskedConv2d(
self.feat_channels,
self.num_anchors * self.cls_out_channels,
3,
padding=1)
self.retina_reg = MaskedConv2d(
self.feat_channels, self.num_anchors * 4, 3, padding=1)
def init_weights(self):
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
self.feature_adaption_cls.init_weights()
self.feature_adaption_reg.init_weights()
bias_cls = bias_init_with_prob(0.01)
normal_init(self.conv_loc, std=0.01, bias=bias_cls)
normal_init(self.conv_shape, std=0.01)
normal_init(self.retina_cls, std=0.01, bias=bias_cls)
normal_init(self.retina_reg, std=0.01)
def forward_single(self, x):
cls_feat = x
reg_feat = x
for cls_conv in self.cls_convs:
cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs:
reg_feat = reg_conv(reg_feat)
loc_pred = self.conv_loc(cls_feat)
shape_pred = self.conv_shape(reg_feat)
cls_feat = self.feature_adaption_cls(cls_feat, shape_pred)
reg_feat = self.feature_adaption_reg(reg_feat, shape_pred)
if not self.training:
mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr
else:
mask = None
cls_score = self.retina_cls(cls_feat, mask)
bbox_pred = self.retina_reg(reg_feat, mask)
return cls_score, bbox_pred, shape_pred, loc_pred
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.core import delta2bbox
from mmdet.ops import nms
from ..registry import HEADS
from .guided_anchor_head import GuidedAnchorHead
@HEADS.register_module
class GARPNHead(GuidedAnchorHead):
"""Guided-Anchor-based RPN head."""
def __init__(self, in_channels, **kwargs):
super(GARPNHead, self).__init__(2, in_channels, **kwargs)
def _init_layers(self):
self.rpn_conv = nn.Conv2d(
self.in_channels, self.feat_channels, 3, padding=1)
super(GARPNHead, self)._init_layers()
def init_weights(self):
normal_init(self.rpn_conv, std=0.01)
super(GARPNHead, self).init_weights()
def forward_single(self, x):
x = self.rpn_conv(x)
x = F.relu(x, inplace=True)
(cls_score, bbox_pred, shape_pred,
loc_pred) = super(GARPNHead, self).forward_single(x)
return cls_score, bbox_pred, shape_pred, loc_pred
def loss(self,
cls_scores,
bbox_preds,
shape_preds,
loc_preds,
gt_bboxes,
img_metas,
cfg,
gt_bboxes_ignore=None):
losses = super(GARPNHead, self).loss(
cls_scores,
bbox_preds,
shape_preds,
loc_preds,
gt_bboxes,
None,
img_metas,
cfg,
gt_bboxes_ignore=gt_bboxes_ignore)
return dict(
loss_rpn_cls=losses['loss_cls'],
loss_rpn_bbox=losses['loss_bbox'],
loss_anchor_shape=losses['loss_shape'],
loss_anchor_loc=losses['loss_loc'])
def get_bboxes_single(self,
cls_scores,
bbox_preds,
mlvl_anchors,
mlvl_masks,
img_shape,
scale_factor,
cfg,
rescale=False):
mlvl_proposals = []
for idx in range(len(cls_scores)):
rpn_cls_score = cls_scores[idx]
rpn_bbox_pred = bbox_preds[idx]
anchors = mlvl_anchors[idx]
mask = mlvl_masks[idx]
assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
# if no location is kept, end.
if mask.sum() == 0:
continue
rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.reshape(-1)
scores = rpn_cls_score.sigmoid()
else:
rpn_cls_score = rpn_cls_score.reshape(-1, 2)
scores = rpn_cls_score.softmax(dim=1)[:, 1]
# filter scores, bbox_pred w.r.t. mask.
# anchors are filtered in get_anchors() beforehand.
scores = scores[mask]
rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1,
4)[mask, :]
if scores.dim() == 0:
rpn_bbox_pred = rpn_bbox_pred.unsqueeze(0)
anchors = anchors.unsqueeze(0)
scores = scores.unsqueeze(0)
# filter anchors, bbox_pred, scores w.r.t. scores
if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
_, topk_inds = scores.topk(cfg.nms_pre)
rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
anchors = anchors[topk_inds, :]
scores = scores[topk_inds]
# get proposals w.r.t. anchors and rpn_bbox_pred
proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
self.target_stds, img_shape)
# filter out too small bboxes
if cfg.min_bbox_size > 0:
w = proposals[:, 2] - proposals[:, 0] + 1
h = proposals[:, 3] - proposals[:, 1] + 1
valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
(h >= cfg.min_bbox_size)).squeeze()
proposals = proposals[valid_inds, :]
scores = scores[valid_inds]
proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
# NMS in current level
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.nms_post, :]
mlvl_proposals.append(proposals)
proposals = torch.cat(mlvl_proposals, 0)
if cfg.nms_across_levels:
# NMS across multi levels
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.max_num, :]
else:
scores = proposals[:, 4]
num = min(cfg.max_num, proposals.shape[0])
_, topk_inds = scores.topk(num)
proposals = proposals[topk_inds, :]
return proposals
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import (AnchorGenerator, anchor_inside_flags, anchor_target,
delta2bbox, force_fp32, ga_loc_target, ga_shape_target,
multi_apply, multiclass_nms)
from mmdet.ops import DeformConv, MaskedConv2d
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob
from .anchor_head import AnchorHead
class FeatureAdaption(nn.Module):
"""Feature Adaption Module.
Feature Adaption Module is implemented based on DCN v1.
It uses anchor shape prediction rather than feature map to
predict offsets of deformable conv layer.
Args:
in_channels (int): Number of channels in the input feature map.
out_channels (int): Number of channels in the output feature map.
kernel_size (int): Deformable conv kernel size.
deformable_groups (int): Deformable conv group size.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
deformable_groups=4):
super(FeatureAdaption, self).__init__()
offset_channels = kernel_size * kernel_size * 2
self.conv_offset = nn.Conv2d(
2, deformable_groups * offset_channels, 1, bias=False)
self.conv_adaption = DeformConv(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=(kernel_size - 1) // 2,
deformable_groups=deformable_groups)
self.relu = nn.ReLU(inplace=True)
def init_weights(self):
normal_init(self.conv_offset, std=0.1)
normal_init(self.conv_adaption, std=0.01)
def forward(self, x, shape):
offset = self.conv_offset(shape.detach())
x = self.relu(self.conv_adaption(x, offset))
return x
@HEADS.register_module
class GuidedAnchorHead(AnchorHead):
"""Guided-Anchor-based head (GA-RPN, GA-RetinaNet, etc.).
This GuidedAnchorHead will predict high-quality feature guided
anchors and locations where anchors will be kept in inference.
There are mainly 3 categories of bounding-boxes.
- Sampled (9) pairs for target assignment. (approxes)
- The square boxes where the predicted anchors are based on.
(squares)
- Guided anchors.
Please refer to https://arxiv.org/abs/1901.03278 for more details.
Args:
num_classes (int): Number of classes.
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of hidden channels.
octave_base_scale (int): Base octave scale of each level of
feature map.
scales_per_octave (int): Number of octave scales in each level of
feature map
octave_ratios (Iterable): octave aspect ratios.
anchor_strides (Iterable): Anchor strides.
anchor_base_sizes (Iterable): Anchor base sizes.
anchoring_means (Iterable): Mean values of anchoring targets.
anchoring_stds (Iterable): Std values of anchoring targets.
target_means (Iterable): Mean values of regression targets.
target_stds (Iterable): Std values of regression targets.
deformable_groups: (int): Group number of DCN in
FeatureAdaption module.
loc_filter_thr (float): Threshold to filter out unconcerned regions.
loss_loc (dict): Config of location loss.
loss_shape (dict): Config of anchor shape loss.
loss_cls (dict): Config of classification loss.
loss_bbox (dict): Config of bbox regression loss.
"""
def __init__(
self,
num_classes,
in_channels,
feat_channels=256,
octave_base_scale=8,
scales_per_octave=3,
octave_ratios=[0.5, 1.0, 2.0],
anchor_strides=[4, 8, 16, 32, 64],
anchor_base_sizes=None,
anchoring_means=(.0, .0, .0, .0),
anchoring_stds=(1.0, 1.0, 1.0, 1.0),
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0),
deformable_groups=4,
loc_filter_thr=0.01,
loss_loc=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
loss_cls=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
loss_weight=1.0)): # yapf: disable
super(AnchorHead, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.feat_channels = feat_channels
self.octave_base_scale = octave_base_scale
self.scales_per_octave = scales_per_octave
self.octave_scales = octave_base_scale * np.array(
[2**(i / scales_per_octave) for i in range(scales_per_octave)])
self.approxs_per_octave = len(self.octave_scales) * len(octave_ratios)
self.octave_ratios = octave_ratios
self.anchor_strides = anchor_strides
self.anchor_base_sizes = list(
anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
self.anchoring_means = anchoring_means
self.anchoring_stds = anchoring_stds
self.target_means = target_means
self.target_stds = target_stds
self.deformable_groups = deformable_groups
self.loc_filter_thr = loc_filter_thr
self.approx_generators = []
self.square_generators = []
for anchor_base in self.anchor_base_sizes:
# Generators for approxs
self.approx_generators.append(
AnchorGenerator(anchor_base, self.octave_scales,
self.octave_ratios))
# Generators for squares
self.square_generators.append(
AnchorGenerator(anchor_base, [self.octave_base_scale], [1.0]))
# one anchor per location
self.num_anchors = 1
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
self.cls_focal_loss = loss_cls['type'] in ['FocalLoss']
self.loc_focal_loss = loss_loc['type'] in ['FocalLoss']
if self.use_sigmoid_cls:
self.cls_out_channels = self.num_classes - 1
else:
self.cls_out_channels = self.num_classes
# build losses
self.loss_loc = build_loss(loss_loc)
self.loss_shape = build_loss(loss_shape)
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.fp16_enabled = False
self._init_layers()
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.conv_loc = nn.Conv2d(self.in_channels, 1, 1)
self.conv_shape = nn.Conv2d(self.in_channels, self.num_anchors * 2, 1)
self.feature_adaption = FeatureAdaption(
self.in_channels,
self.feat_channels,
kernel_size=3,
deformable_groups=self.deformable_groups)
self.conv_cls = MaskedConv2d(self.feat_channels,
self.num_anchors * self.cls_out_channels,
1)
self.conv_reg = MaskedConv2d(self.feat_channels, self.num_anchors * 4,
1)
def init_weights(self):
normal_init(self.conv_cls, std=0.01)
normal_init(self.conv_reg, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.conv_loc, std=0.01, bias=bias_cls)
normal_init(self.conv_shape, std=0.01)
self.feature_adaption.init_weights()
def forward_single(self, x):
loc_pred = self.conv_loc(x)
shape_pred = self.conv_shape(x)
x = self.feature_adaption(x, shape_pred)
# masked conv is only used during inference for speed-up
if not self.training:
mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr
else:
mask = None
cls_score = self.conv_cls(x, mask)
bbox_pred = self.conv_reg(x, mask)
return cls_score, bbox_pred, shape_pred, loc_pred
def forward(self, feats):
return multi_apply(self.forward_single, feats)
def get_sampled_approxs(self,
featmap_sizes,
img_metas,
cfg,
device='cuda'):
"""Get sampled approxs and inside flags according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
device (torch.device | str): device for returned tensors
Returns:
tuple: approxes of each image, inside flags of each image
"""
num_imgs = len(img_metas)
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# approxes for one time
multi_level_approxs = []
for i in range(num_levels):
approxs = self.approx_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i], device=device)
multi_level_approxs.append(approxs)
approxs_list = [multi_level_approxs for _ in range(num_imgs)]
# for each image, we compute inside flags of multi level approxes
inside_flag_list = []
for img_id, img_meta in enumerate(img_metas):
multi_level_flags = []
multi_level_approxs = approxs_list[img_id]
for i in range(num_levels):
approxs = multi_level_approxs[i]
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = img_meta['pad_shape'][:2]
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.approx_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w),
device=device)
inside_flags_list = []
for i in range(self.approxs_per_octave):
split_valid_flags = flags[i::self.approxs_per_octave]
split_approxs = approxs[i::self.approxs_per_octave, :]
inside_flags = anchor_inside_flags(
split_approxs, split_valid_flags,
img_meta['img_shape'][:2], cfg.allowed_border)
inside_flags_list.append(inside_flags)
# inside_flag for a position is true if any anchor in this
# position is true
inside_flags = (
torch.stack(inside_flags_list, 0).sum(dim=0) > 0)
multi_level_flags.append(inside_flags)
inside_flag_list.append(multi_level_flags)
return approxs_list, inside_flag_list
def get_anchors(self,
featmap_sizes,
shape_preds,
loc_preds,
img_metas,
use_loc_filter=False,
device='cuda'):
"""Get squares according to feature map sizes and guided
anchors.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
shape_preds (list[tensor]): Multi-level shape predictions.
loc_preds (list[tensor]): Multi-level location predictions.
img_metas (list[dict]): Image meta info.
use_loc_filter (bool): Use loc filter or not.
device (torch.device | str): device for returned tensors
Returns:
tuple: square approxs of each image, guided anchors of each image,
loc masks of each image
"""
num_imgs = len(img_metas)
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# squares for one time
multi_level_squares = []
for i in range(num_levels):
squares = self.square_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i], device=device)
multi_level_squares.append(squares)
squares_list = [multi_level_squares for _ in range(num_imgs)]
# for each image, we compute multi level guided anchors
guided_anchors_list = []
loc_mask_list = []
for img_id, img_meta in enumerate(img_metas):
multi_level_guided_anchors = []
multi_level_loc_mask = []
for i in range(num_levels):
squares = squares_list[img_id][i]
shape_pred = shape_preds[i][img_id]
loc_pred = loc_preds[i][img_id]
guided_anchors, loc_mask = self.get_guided_anchors_single(
squares,
shape_pred,
loc_pred,
use_loc_filter=use_loc_filter)
multi_level_guided_anchors.append(guided_anchors)
multi_level_loc_mask.append(loc_mask)
guided_anchors_list.append(multi_level_guided_anchors)
loc_mask_list.append(multi_level_loc_mask)
return squares_list, guided_anchors_list, loc_mask_list
def get_guided_anchors_single(self,
squares,
shape_pred,
loc_pred,
use_loc_filter=False):
"""Get guided anchors and loc masks for a single level.
Args:
square (tensor): Squares of a single level.
shape_pred (tensor): Shape predections of a single level.
loc_pred (tensor): Loc predections of a single level.
use_loc_filter (list[tensor]): Use loc filter or not.
Returns:
tuple: guided anchors, location masks
"""
# calculate location filtering mask
loc_pred = loc_pred.sigmoid().detach()
if use_loc_filter:
loc_mask = loc_pred >= self.loc_filter_thr
else:
loc_mask = loc_pred >= 0.0
mask = loc_mask.permute(1, 2, 0).expand(-1, -1, self.num_anchors)
mask = mask.contiguous().view(-1)
# calculate guided anchors
squares = squares[mask]
anchor_deltas = shape_pred.permute(1, 2, 0).contiguous().view(
-1, 2).detach()[mask]
bbox_deltas = anchor_deltas.new_full(squares.size(), 0)
bbox_deltas[:, 2:] = anchor_deltas
guided_anchors = delta2bbox(
squares,
bbox_deltas,
self.anchoring_means,
self.anchoring_stds,
wh_ratio_clip=1e-6)
return guided_anchors, mask
def loss_shape_single(self, shape_pred, bbox_anchors, bbox_gts,
anchor_weights, anchor_total_num):
shape_pred = shape_pred.permute(0, 2, 3, 1).contiguous().view(-1, 2)
bbox_anchors = bbox_anchors.contiguous().view(-1, 4)
bbox_gts = bbox_gts.contiguous().view(-1, 4)
anchor_weights = anchor_weights.contiguous().view(-1, 4)
bbox_deltas = bbox_anchors.new_full(bbox_anchors.size(), 0)
bbox_deltas[:, 2:] += shape_pred
# filter out negative samples to speed-up weighted_bounded_iou_loss
inds = torch.nonzero(anchor_weights[:, 0] > 0).squeeze(1)
bbox_deltas_ = bbox_deltas[inds]
bbox_anchors_ = bbox_anchors[inds]
bbox_gts_ = bbox_gts[inds]
anchor_weights_ = anchor_weights[inds]
pred_anchors_ = delta2bbox(
bbox_anchors_,
bbox_deltas_,
self.anchoring_means,
self.anchoring_stds,
wh_ratio_clip=1e-6)
loss_shape = self.loss_shape(
pred_anchors_,
bbox_gts_,
anchor_weights_,
avg_factor=anchor_total_num)
return loss_shape
def loss_loc_single(self, loc_pred, loc_target, loc_weight, loc_avg_factor,
cfg):
loss_loc = self.loss_loc(
loc_pred.reshape(-1, 1),
loc_target.reshape(-1, 1).long(),
loc_weight.reshape(-1, 1),
avg_factor=loc_avg_factor)
return loss_loc
@force_fp32(
apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
def loss(self,
cls_scores,
bbox_preds,
shape_preds,
loc_preds,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.approx_generators)
device = cls_scores[0].device
# get loc targets
loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
gt_bboxes,
featmap_sizes,
self.octave_base_scale,
self.anchor_strides,
center_ratio=cfg.center_ratio,
ignore_ratio=cfg.ignore_ratio)
# get sampled approxes
approxs_list, inside_flag_list = self.get_sampled_approxs(
featmap_sizes, img_metas, cfg, device=device)
# get squares and guided anchors
squares_list, guided_anchors_list, _ = self.get_anchors(
featmap_sizes, shape_preds, loc_preds, img_metas, device=device)
# get shape targets
sampling = False if not hasattr(cfg, 'ga_sampler') else True
shape_targets = ga_shape_target(
approxs_list,
inside_flag_list,
squares_list,
gt_bboxes,
img_metas,
self.approxs_per_octave,
cfg,
sampling=sampling)
if shape_targets is None:
return None
(bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
anchor_bg_num) = shape_targets
anchor_total_num = (
anchor_fg_num if not sampling else anchor_fg_num + anchor_bg_num)
# get anchor targets
sampling = False if self.cls_focal_loss else True
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
cls_reg_targets = anchor_target(
guided_anchors_list,
inside_flag_list,
gt_bboxes,
img_metas,
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=sampling)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (
num_total_pos if self.cls_focal_loss else num_total_pos +
num_total_neg)
# get classification and bbox regression losses
losses_cls, losses_bbox = multi_apply(
self.loss_single,
cls_scores,
bbox_preds,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_samples,
cfg=cfg)
# get anchor location loss
losses_loc = []
for i in range(len(loc_preds)):
loss_loc = self.loss_loc_single(
loc_preds[i],
loc_targets[i],
loc_weights[i],
loc_avg_factor=loc_avg_factor,
cfg=cfg)
losses_loc.append(loss_loc)
# get anchor shape loss
losses_shape = []
for i in range(len(shape_preds)):
loss_shape = self.loss_shape_single(
shape_preds[i],
bbox_anchors_list[i],
bbox_gts_list[i],
anchor_weights_list[i],
anchor_total_num=anchor_total_num)
losses_shape.append(loss_shape)
return dict(
loss_cls=losses_cls,
loss_bbox=losses_bbox,
loss_shape=losses_shape,
loss_loc=losses_loc)
@force_fp32(
apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
def get_bboxes(self,
cls_scores,
bbox_preds,
shape_preds,
loc_preds,
img_metas,
cfg,
rescale=False):
assert len(cls_scores) == len(bbox_preds) == len(shape_preds) == len(
loc_preds)
num_levels = len(cls_scores)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
device = cls_scores[0].device
# get guided anchors
_, guided_anchors, loc_masks = self.get_anchors(
featmap_sizes,
shape_preds,
loc_preds,
img_metas,
use_loc_filter=not self.training,
device=device)
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds[i][img_id].detach() for i in range(num_levels)
]
guided_anchor_list = [
guided_anchors[img_id][i].detach() for i in range(num_levels)
]
loc_mask_list = [
loc_masks[img_id][i].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
guided_anchor_list,
loc_mask_list, img_shape,
scale_factor, cfg, rescale)
result_list.append(proposals)
return result_list
def get_bboxes_single(self,
cls_scores,
bbox_preds,
mlvl_anchors,
mlvl_masks,
img_shape,
scale_factor,
cfg,
rescale=False):
assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
mlvl_bboxes = []
mlvl_scores = []
for cls_score, bbox_pred, anchors, mask in zip(cls_scores, bbox_preds,
mlvl_anchors,
mlvl_masks):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
# if no location is kept, end.
if mask.sum() == 0:
continue
# reshape scores and bbox_pred
cls_score = cls_score.permute(1, 2,
0).reshape(-1, self.cls_out_channels)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
scores = cls_score.softmax(-1)
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
# filter scores, bbox_pred w.r.t. mask.
# anchors are filtered in get_anchors() beforehand.
scores = scores[mask, :]
bbox_pred = bbox_pred[mask, :]
if scores.dim() == 0:
anchors = anchors.unsqueeze(0)
scores = scores.unsqueeze(0)
bbox_pred = bbox_pred.unsqueeze(0)
# filter anchors, bbox_pred, scores w.r.t. scores
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
if self.use_sigmoid_cls:
max_scores, _ = scores.max(dim=1)
else:
max_scores, _ = scores[:, 1:].max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
anchors = anchors[topk_inds, :]
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
self.target_stds, img_shape)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
if self.use_sigmoid_cls:
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
# multi class NMS
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import (PointGenerator, multi_apply, multiclass_nms,
point_target)
from mmdet.ops import DeformConv
from ..builder import build_loss
from ..registry import HEADS
from ..utils import ConvModule, bias_init_with_prob
@HEADS.register_module
class RepPointsHead(nn.Module):
"""RepPoint head.
Args:
in_channels (int): Number of channels in the input feature map.
feat_channels (int): Number of channels of the feature map.
point_feat_channels (int): Number of channels of points features.
stacked_convs (int): How many conv layers are used.
gradient_mul (float): The multiplier to gradients from
points refinement and recognition.
point_strides (Iterable): points strides.
point_base_scale (int): bbox scale for assigning labels.
loss_cls (dict): Config of classification loss.
loss_bbox_init (dict): Config of initial points loss.
loss_bbox_refine (dict): Config of points loss in refinement.
use_grid_points (bool): If we use bounding box representation, the
reppoints is represented as grid points on the bounding box.
center_init (bool): Whether to use center point assignment.
transform_method (str): The methods to transform RepPoints to bbox.
""" # noqa: W605
def __init__(self,
num_classes,
in_channels,
feat_channels=256,
point_feat_channels=256,
stacked_convs=3,
num_points=9,
gradient_mul=0.1,
point_strides=[8, 16, 32, 64, 128],
point_base_scale=4,
conv_cfg=None,
norm_cfg=None,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox_init=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
loss_bbox_refine=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
use_grid_points=False,
center_init=True,
transform_method='moment',
moment_mul=0.01):
super(RepPointsHead, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.feat_channels = feat_channels
self.point_feat_channels = point_feat_channels
self.stacked_convs = stacked_convs
self.num_points = num_points
self.gradient_mul = gradient_mul
self.point_base_scale = point_base_scale
self.point_strides = point_strides
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
self.sampling = loss_cls['type'] not in ['FocalLoss']
self.loss_cls = build_loss(loss_cls)
self.loss_bbox_init = build_loss(loss_bbox_init)
self.loss_bbox_refine = build_loss(loss_bbox_refine)
self.use_grid_points = use_grid_points
self.center_init = center_init
self.transform_method = transform_method
if self.transform_method == 'moment':
self.moment_transfer = nn.Parameter(
data=torch.zeros(2), requires_grad=True)
self.moment_mul = moment_mul
if self.use_sigmoid_cls:
self.cls_out_channels = self.num_classes - 1
else:
self.cls_out_channels = self.num_classes
self.point_generators = [PointGenerator() for _ in self.point_strides]
# we use deformable conv to extract points features
self.dcn_kernel = int(np.sqrt(num_points))
self.dcn_pad = int((self.dcn_kernel - 1) / 2)
assert self.dcn_kernel * self.dcn_kernel == num_points, \
"The points number should be a square number."
assert self.dcn_kernel % 2 == 1, \
"The points number should be an odd square number."
dcn_base = np.arange(-self.dcn_pad,
self.dcn_pad + 1).astype(np.float64)
dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
(-1))
self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)
self._init_layers()
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points
self.reppoints_cls_conv = DeformConv(self.feat_channels,
self.point_feat_channels,
self.dcn_kernel, 1, self.dcn_pad)
self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
self.cls_out_channels, 1, 1, 0)
self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
self.point_feat_channels, 3,
1, 1)
self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
pts_out_dim, 1, 1, 0)
self.reppoints_pts_refine_conv = DeformConv(self.feat_channels,
self.point_feat_channels,
self.dcn_kernel, 1,
self.dcn_pad)
self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
pts_out_dim, 1, 1, 0)
def init_weights(self):
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.reppoints_cls_conv, std=0.01)
normal_init(self.reppoints_cls_out, std=0.01, bias=bias_cls)
normal_init(self.reppoints_pts_init_conv, std=0.01)
normal_init(self.reppoints_pts_init_out, std=0.01)
normal_init(self.reppoints_pts_refine_conv, std=0.01)
normal_init(self.reppoints_pts_refine_out, std=0.01)
def points2bbox(self, pts, y_first=True):
"""
Converting the points set into bounding box.
:param pts: the input points sets (fields), each points
set (fields) is represented as 2n scalar.
:param y_first: if y_fisrt=True, the point set is represented as
[y1, x1, y2, x2 ... yn, xn], otherwise the point set is
represented as [x1, y1, x2, y2 ... xn, yn].
:return: each points set is converting to a bbox [x1, y1, x2, y2].
"""
pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1,
...]
pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0,
...]
if self.transform_method == 'minmax':
bbox_left = pts_x.min(dim=1, keepdim=True)[0]
bbox_right = pts_x.max(dim=1, keepdim=True)[0]
bbox_up = pts_y.min(dim=1, keepdim=True)[0]
bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
dim=1)
elif self.transform_method == 'partial_minmax':
pts_y = pts_y[:, :4, ...]
pts_x = pts_x[:, :4, ...]
bbox_left = pts_x.min(dim=1, keepdim=True)[0]
bbox_right = pts_x.max(dim=1, keepdim=True)[0]
bbox_up = pts_y.min(dim=1, keepdim=True)[0]
bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
dim=1)
elif self.transform_method == 'moment':
pts_y_mean = pts_y.mean(dim=1, keepdim=True)
pts_x_mean = pts_x.mean(dim=1, keepdim=True)
pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
moment_transfer = (self.moment_transfer * self.moment_mul) + (
self.moment_transfer.detach() * (1 - self.moment_mul))
moment_width_transfer = moment_transfer[0]
moment_height_transfer = moment_transfer[1]
half_width = pts_x_std * torch.exp(moment_width_transfer)
half_height = pts_y_std * torch.exp(moment_height_transfer)
bbox = torch.cat([
pts_x_mean - half_width, pts_y_mean - half_height,
pts_x_mean + half_width, pts_y_mean + half_height
],
dim=1)
else:
raise NotImplementedError
return bbox
def gen_grid_from_reg(self, reg, previous_boxes):
"""
Base on the previous bboxes and regression values, we compute the
regressed bboxes and generate the grids on the bboxes.
:param reg: the regression value to previous bboxes.
:param previous_boxes: previous bboxes.
:return: generate grids on the regressed bboxes.
"""
b, _, h, w = reg.shape
bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2.
bwh = (previous_boxes[:, 2:, ...] -
previous_boxes[:, :2, ...]).clamp(min=1e-6)
grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp(
reg[:, 2:, ...])
grid_wh = bwh * torch.exp(reg[:, 2:, ...])
grid_left = grid_topleft[:, [0], ...]
grid_top = grid_topleft[:, [1], ...]
grid_width = grid_wh[:, [0], ...]
grid_height = grid_wh[:, [1], ...]
intervel = torch.linspace(0., 1., self.dcn_kernel).view(
1, self.dcn_kernel, 1, 1).type_as(reg)
grid_x = grid_left + grid_width * intervel
grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
grid_x = grid_x.view(b, -1, h, w)
grid_y = grid_top + grid_height * intervel
grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
grid_y = grid_y.view(b, -1, h, w)
grid_yx = torch.stack([grid_y, grid_x], dim=2)
grid_yx = grid_yx.view(b, -1, h, w)
regressed_bbox = torch.cat([
grid_left, grid_top, grid_left + grid_width, grid_top + grid_height
], 1)
return grid_yx, regressed_bbox
def forward_single(self, x):
dcn_base_offset = self.dcn_base_offset.type_as(x)
# If we use center_init, the initial reppoints is from center points.
# If we use bounding bbox representation, the initial reppoints is
# from regular grid placed on a pre-defined bbox.
if self.use_grid_points or not self.center_init:
scale = self.point_base_scale / 2
points_init = dcn_base_offset / dcn_base_offset.max() * scale
bbox_init = x.new_tensor([-scale, -scale, scale,
scale]).view(1, 4, 1, 1)
else:
points_init = 0
cls_feat = x
pts_feat = x
for cls_conv in self.cls_convs:
cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs:
pts_feat = reg_conv(pts_feat)
# initialize reppoints
pts_out_init = self.reppoints_pts_init_out(
self.relu(self.reppoints_pts_init_conv(pts_feat)))
if self.use_grid_points:
pts_out_init, bbox_out_init = self.gen_grid_from_reg(
pts_out_init, bbox_init.detach())
else:
pts_out_init = pts_out_init + points_init
# refine and classify reppoints
pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach(
) + self.gradient_mul * pts_out_init
dcn_offset = pts_out_init_grad_mul - dcn_base_offset
cls_out = self.reppoints_cls_out(
self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
pts_out_refine = self.reppoints_pts_refine_out(
self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
if self.use_grid_points:
pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(
pts_out_refine, bbox_out_init.detach())
else:
pts_out_refine = pts_out_refine + pts_out_init.detach()
return cls_out, pts_out_init, pts_out_refine
def forward(self, feats):
return multi_apply(self.forward_single, feats)
def get_points(self, featmap_sizes, img_metas):
"""Get points according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
Returns:
tuple: points of each image, valid flags of each image
"""
num_imgs = len(img_metas)
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# points center for one time
multi_level_points = []
for i in range(num_levels):
points = self.point_generators[i].grid_points(
featmap_sizes[i], self.point_strides[i])
multi_level_points.append(points)
points_list = [[point.clone() for point in multi_level_points]
for _ in range(num_imgs)]
# for each image, we compute valid flags of multi level grids
valid_flag_list = []
for img_id, img_meta in enumerate(img_metas):
multi_level_flags = []
for i in range(num_levels):
point_stride = self.point_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = img_meta['pad_shape'][:2]
valid_feat_h = min(int(np.ceil(h / point_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / point_stride)), feat_w)
flags = self.point_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
multi_level_flags.append(flags)
valid_flag_list.append(multi_level_flags)
return points_list, valid_flag_list
def centers_to_bboxes(self, point_list):
"""Get bboxes according to center points. Only used in MaxIOUAssigner.
"""
bbox_list = []
for i_img, point in enumerate(point_list):
bbox = []
for i_lvl in range(len(self.point_strides)):
scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
bbox_shift = torch.Tensor([-scale, -scale, scale,
scale]).view(1, 4).type_as(point[0])
bbox_center = torch.cat(
[point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
bbox.append(bbox_center + bbox_shift)
bbox_list.append(bbox)
return bbox_list
def offset_to_pts(self, center_list, pred_list):
"""Change from point offset to point coordinate.
"""
pts_list = []
for i_lvl in range(len(self.point_strides)):
pts_lvl = []
for i_img in range(len(center_list)):
pts_center = center_list[i_img][i_lvl][:, :2].repeat(
1, self.num_points)
pts_shift = pred_list[i_lvl][i_img]
yx_pts_shift = pts_shift.permute(1, 2, 0).view(
-1, 2 * self.num_points)
y_pts_shift = yx_pts_shift[..., 0::2]
x_pts_shift = yx_pts_shift[..., 1::2]
xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
pts_lvl.append(pts)
pts_lvl = torch.stack(pts_lvl, 0)
pts_list.append(pts_lvl)
return pts_list
def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels,
label_weights, bbox_gt_init, bbox_weights_init,
bbox_gt_refine, bbox_weights_refine, stride,
num_total_samples_init, num_total_samples_refine):
# classification loss
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
loss_cls = self.loss_cls(
cls_score,
labels,
label_weights,
avg_factor=num_total_samples_refine)
# points loss
bbox_gt_init = bbox_gt_init.reshape(-1, 4)
bbox_weights_init = bbox_weights_init.reshape(-1, 4)
bbox_pred_init = self.points2bbox(
pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False)
bbox_gt_refine = bbox_gt_refine.reshape(-1, 4)
bbox_weights_refine = bbox_weights_refine.reshape(-1, 4)
bbox_pred_refine = self.points2bbox(
pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False)
normalize_term = self.point_base_scale * stride
loss_pts_init = self.loss_bbox_init(
bbox_pred_init / normalize_term,
bbox_gt_init / normalize_term,
bbox_weights_init,
avg_factor=num_total_samples_init)
loss_pts_refine = self.loss_bbox_refine(
bbox_pred_refine / normalize_term,
bbox_gt_refine / normalize_term,
bbox_weights_refine,
avg_factor=num_total_samples_refine)
return loss_cls, loss_pts_init, loss_pts_refine
def loss(self,
cls_scores,
pts_preds_init,
pts_preds_refine,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.point_generators)
label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
# target for initial stage
center_list, valid_flag_list = self.get_points(featmap_sizes,
img_metas)
pts_coordinate_preds_init = self.offset_to_pts(center_list,
pts_preds_init)
if cfg.init.assigner['type'] == 'PointAssigner':
# Assign target for center list
candidate_list = center_list
else:
# transform center list to bbox list and
# assign target for bbox list
bbox_list = self.centers_to_bboxes(center_list)
candidate_list = bbox_list
cls_reg_targets_init = point_target(
candidate_list,
valid_flag_list,
gt_bboxes,
img_metas,
cfg.init,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=self.sampling)
(*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
num_total_pos_init, num_total_neg_init) = cls_reg_targets_init
num_total_samples_init = (
num_total_pos_init +
num_total_neg_init if self.sampling else num_total_pos_init)
# target for refinement stage
center_list, valid_flag_list = self.get_points(featmap_sizes,
img_metas)
pts_coordinate_preds_refine = self.offset_to_pts(
center_list, pts_preds_refine)
bbox_list = []
for i_img, center in enumerate(center_list):
bbox = []
for i_lvl in range(len(pts_preds_refine)):
bbox_preds_init = self.points2bbox(
pts_preds_init[i_lvl].detach())
bbox_shift = bbox_preds_init * self.point_strides[i_lvl]
bbox_center = torch.cat(
[center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1)
bbox.append(bbox_center +
bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4))
bbox_list.append(bbox)
cls_reg_targets_refine = point_target(
bbox_list,
valid_flag_list,
gt_bboxes,
img_metas,
cfg.refine,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=self.sampling)
(labels_list, label_weights_list, bbox_gt_list_refine,
candidate_list_refine, bbox_weights_list_refine, num_total_pos_refine,
num_total_neg_refine) = cls_reg_targets_refine
num_total_samples_refine = (
num_total_pos_refine +
num_total_neg_refine if self.sampling else num_total_pos_refine)
# compute loss
losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
self.loss_single,
cls_scores,
pts_coordinate_preds_init,
pts_coordinate_preds_refine,
labels_list,
label_weights_list,
bbox_gt_list_init,
bbox_weights_list_init,
bbox_gt_list_refine,
bbox_weights_list_refine,
self.point_strides,
num_total_samples_init=num_total_samples_init,
num_total_samples_refine=num_total_samples_refine)
loss_dict_all = {
'loss_cls': losses_cls,
'loss_pts_init': losses_pts_init,
'loss_pts_refine': losses_pts_refine
}
return loss_dict_all
def get_bboxes(self,
cls_scores,
pts_preds_init,
pts_preds_refine,
img_metas,
cfg,
rescale=False,
nms=True):
assert len(cls_scores) == len(pts_preds_refine)
bbox_preds_refine = [
self.points2bbox(pts_pred_refine)
for pts_pred_refine in pts_preds_refine
]
num_levels = len(cls_scores)
mlvl_points = [
self.point_generators[i].grid_points(cls_scores[i].size()[-2:],
self.point_strides[i])
for i in range(num_levels)
]
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds_refine[i][img_id].detach()
for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
mlvl_points, img_shape,
scale_factor, cfg, rescale, nms)
result_list.append(proposals)
return result_list
def get_bboxes_single(self,
cls_scores,
bbox_preds,
mlvl_points,
img_shape,
scale_factor,
cfg,
rescale=False,
nms=True):
assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
mlvl_bboxes = []
mlvl_scores = []
for i_lvl, (cls_score, bbox_pred, points) in enumerate(
zip(cls_scores, bbox_preds, mlvl_points)):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(1, 2,
0).reshape(-1, self.cls_out_channels)
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
scores = cls_score.softmax(-1)
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
if self.use_sigmoid_cls:
max_scores, _ = scores.max(dim=1)
else:
max_scores, _ = scores[:, 1:].max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
points = points[topk_inds, :]
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1)
bboxes = bbox_pred * self.point_strides[i_lvl] + bbox_pos_center
x1 = bboxes[:, 0].clamp(min=0, max=img_shape[1])
y1 = bboxes[:, 1].clamp(min=0, max=img_shape[0])
x2 = bboxes[:, 2].clamp(min=0, max=img_shape[1])
y2 = bboxes[:, 3].clamp(min=0, max=img_shape[0])
bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
if self.use_sigmoid_cls:
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
if nms:
det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
else:
return mlvl_bboxes, mlvl_scores
import numpy as np
import torch.nn as nn
from mmcv.cnn import normal_init
from ..registry import HEADS
from ..utils import ConvModule, bias_init_with_prob
from .anchor_head import AnchorHead
@HEADS.register_module
class RetinaHead(AnchorHead):
"""
An anchor-based head used in [1]_.
The head contains two subnetworks. The first classifies anchor boxes and
the second regresses deltas for the anchors.
References:
.. [1] https://arxiv.org/pdf/1708.02002.pdf
Example:
>>> import torch
>>> self = RetinaHead(11, 7)
>>> x = torch.rand(1, 7, 32, 32)
>>> cls_score, bbox_pred = self.forward_single(x)
>>> # Each anchor predicts a score for each class except background
>>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
>>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
>>> assert cls_per_anchor == (self.num_classes - 1)
>>> assert box_per_anchor == 4
"""
def __init__(self,
num_classes,
in_channels,
stacked_convs=4,
octave_base_scale=4,
scales_per_octave=3,
conv_cfg=None,
norm_cfg=None,
**kwargs):
self.stacked_convs = stacked_convs
self.octave_base_scale = octave_base_scale
self.scales_per_octave = scales_per_octave
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
octave_scales = np.array(
[2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
super(RetinaHead, self).__init__(
num_classes, in_channels, anchor_scales=anchor_scales, **kwargs)
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.retina_cls = nn.Conv2d(
self.feat_channels,
self.num_anchors * self.cls_out_channels,
3,
padding=1)
self.retina_reg = nn.Conv2d(
self.feat_channels, self.num_anchors * 4, 3, padding=1)
def init_weights(self):
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.retina_cls, std=0.01, bias=bias_cls)
normal_init(self.retina_reg, std=0.01)
def forward_single(self, x):
cls_feat = x
reg_feat = x
for cls_conv in self.cls_convs:
cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs:
reg_feat = reg_conv(reg_feat)
cls_score = self.retina_cls(cls_feat)
bbox_pred = self.retina_reg(reg_feat)
return cls_score, bbox_pred
import numpy as np
import torch.nn as nn
from mmcv.cnn import normal_init
from ..registry import HEADS
from ..utils import ConvModule, bias_init_with_prob
from .anchor_head import AnchorHead
@HEADS.register_module
class RetinaSepBNHead(AnchorHead):
""""RetinaHead with separate BN.
In RetinaHead, conv/norm layers are shared across different FPN levels,
while in RetinaSepBNHead, conv layers are shared across different FPN
levels, but BN layers are separated.
"""
def __init__(self,
num_classes,
num_ins,
in_channels,
stacked_convs=4,
octave_base_scale=4,
scales_per_octave=3,
conv_cfg=None,
norm_cfg=None,
**kwargs):
self.stacked_convs = stacked_convs
self.octave_base_scale = octave_base_scale
self.scales_per_octave = scales_per_octave
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.num_ins = num_ins
octave_scales = np.array(
[2**(i / scales_per_octave) for i in range(scales_per_octave)])
anchor_scales = octave_scales * octave_base_scale
super(RetinaSepBNHead, self).__init__(
num_classes, in_channels, anchor_scales=anchor_scales, **kwargs)
def _init_layers(self):
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
for i in range(self.num_ins):
cls_convs = nn.ModuleList()
reg_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.cls_convs.append(cls_convs)
self.reg_convs.append(reg_convs)
for i in range(self.stacked_convs):
for j in range(1, self.num_ins):
self.cls_convs[j][i].conv = self.cls_convs[0][i].conv
self.reg_convs[j][i].conv = self.reg_convs[0][i].conv
self.retina_cls = nn.Conv2d(
self.feat_channels,
self.num_anchors * self.cls_out_channels,
3,
padding=1)
self.retina_reg = nn.Conv2d(
self.feat_channels, self.num_anchors * 4, 3, padding=1)
def init_weights(self):
for m in self.cls_convs[0]:
normal_init(m.conv, std=0.01)
for m in self.reg_convs[0]:
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.retina_cls, std=0.01, bias=bias_cls)
normal_init(self.retina_reg, std=0.01)
def forward(self, feats):
cls_scores = []
bbox_preds = []
for i, x in enumerate(feats):
cls_feat = feats[i]
reg_feat = feats[i]
for cls_conv in self.cls_convs[i]:
cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs[i]:
reg_feat = reg_conv(reg_feat)
cls_score = self.retina_cls(cls_feat)
bbox_pred = self.retina_reg(reg_feat)
cls_scores.append(cls_score)
bbox_preds.append(bbox_pred)
return cls_scores, bbox_preds
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.core import delta2bbox
from mmdet.ops import nms
from ..registry import HEADS
from .anchor_head import AnchorHead
@HEADS.register_module
class RPNHead(AnchorHead):
def __init__(self, in_channels, **kwargs):
super(RPNHead, self).__init__(2, in_channels, **kwargs)
def _init_layers(self):
self.rpn_conv = nn.Conv2d(
self.in_channels, self.feat_channels, 3, padding=1)
self.rpn_cls = nn.Conv2d(self.feat_channels,
self.num_anchors * self.cls_out_channels, 1)
self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
def init_weights(self):
normal_init(self.rpn_conv, std=0.01)
normal_init(self.rpn_cls, std=0.01)
normal_init(self.rpn_reg, std=0.01)
def forward_single(self, x):
x = self.rpn_conv(x)
x = F.relu(x, inplace=True)
rpn_cls_score = self.rpn_cls(x)
rpn_bbox_pred = self.rpn_reg(x)
return rpn_cls_score, rpn_bbox_pred
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
img_metas,
cfg,
gt_bboxes_ignore=None):
losses = super(RPNHead, self).loss(
cls_scores,
bbox_preds,
gt_bboxes,
None,
img_metas,
cfg,
gt_bboxes_ignore=gt_bboxes_ignore)
return dict(
loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])
def get_bboxes_single(self,
cls_scores,
bbox_preds,
mlvl_anchors,
img_shape,
scale_factor,
cfg,
rescale=False):
mlvl_proposals = []
for idx in range(len(cls_scores)):
rpn_cls_score = cls_scores[idx]
rpn_bbox_pred = bbox_preds[idx]
assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
if self.use_sigmoid_cls:
rpn_cls_score = rpn_cls_score.reshape(-1)
scores = rpn_cls_score.sigmoid()
else:
rpn_cls_score = rpn_cls_score.reshape(-1, 2)
scores = rpn_cls_score.softmax(dim=1)[:, 1]
rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
anchors = mlvl_anchors[idx]
if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
_, topk_inds = scores.topk(cfg.nms_pre)
rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
anchors = anchors[topk_inds, :]
scores = scores[topk_inds]
proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means,
self.target_stds, img_shape)
if cfg.min_bbox_size > 0:
w = proposals[:, 2] - proposals[:, 0] + 1
h = proposals[:, 3] - proposals[:, 1] + 1
valid_inds = torch.nonzero((w >= cfg.min_bbox_size) &
(h >= cfg.min_bbox_size)).squeeze()
proposals = proposals[valid_inds, :]
scores = scores[valid_inds]
proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1)
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.nms_post, :]
mlvl_proposals.append(proposals)
proposals = torch.cat(mlvl_proposals, 0)
if cfg.nms_across_levels:
proposals, _ = nms(proposals, cfg.nms_thr)
proposals = proposals[:cfg.max_num, :]
else:
scores = proposals[:, 4]
num = min(cfg.max_num, proposals.shape[0])
_, topk_inds = scores.topk(num)
proposals = proposals[topk_inds, :]
return proposals
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, bbox2roi, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule
INF = 1e8
def center_of_mass(bitmasks):
_, h, w = bitmasks.size()
ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
center_x = m10 / m00
center_y = m01 / m00
return center_x, center_y
def points_nms(heat, kernel=2):
# kernel must be 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=1)
keep = (hmax[:, :, :-1, :-1] == heat).float()
return heat * keep
def dice_loss(input, target):
input = input.contiguous().view(input.size()[0], -1)
target = target.contiguous().view(target.size()[0], -1).float()
a = torch.sum(input * target, 1)
b = torch.sum(input * input, 1) + 0.001
c = torch.sum(target * target, 1) + 0.001
d = (2 * a) / (b + c)
return 1-d
@HEADS.register_module
class SOLOHead(nn.Module):
def __init__(self,
num_classes,
in_channels,
seg_feat_channels=256,
stacked_convs=4,
strides=(4, 8, 16, 32, 64),
base_edge_list=(16, 32, 64, 128, 256),
scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
sigma=0.4,
num_grids=None,
cate_down_pos=0,
with_deform=False,
loss_ins=None,
loss_cate=None,
conv_cfg=None,
norm_cfg=None):
super(SOLOHead, self).__init__()
self.num_classes = num_classes
self.seg_num_grids = num_grids
self.cate_out_channels = self.num_classes - 1
self.in_channels = in_channels
self.seg_feat_channels = seg_feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.sigma = sigma
self.cate_down_pos = cate_down_pos
self.base_edge_list = base_edge_list
self.scale_ranges = scale_ranges
self.with_deform = with_deform
self.loss_cate = build_loss(loss_cate)
self.ins_loss_weight = loss_ins['loss_weight']
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self._init_layers()
def _init_layers(self):
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
self.ins_convs = nn.ModuleList()
self.cate_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
self.ins_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
chn = self.in_channels if i == 0 else self.seg_feat_channels
self.cate_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
self.solo_ins_list = nn.ModuleList()
for seg_num_grid in self.seg_num_grids:
self.solo_ins_list.append(
nn.Conv2d(
self.seg_feat_channels, seg_num_grid**2, 1))
self.solo_cate = nn.Conv2d(
self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
def init_weights(self):
for m in self.ins_convs:
normal_init(m.conv, std=0.01)
for m in self.cate_convs:
normal_init(m.conv, std=0.01)
bias_ins = bias_init_with_prob(0.01)
for m in self.solo_ins_list:
normal_init(m, std=0.01, bias=bias_ins)
bias_cate = bias_init_with_prob(0.01)
normal_init(self.solo_cate, std=0.01, bias=bias_cate)
def forward(self, feats, eval=False):
new_feats = self.split_feats(feats)
featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
ins_pred, cate_pred = multi_apply(self.forward_single, new_feats,
list(range(len(self.seg_num_grids))),
eval=eval, upsampled_size=upsampled_size)
return ins_pred, cate_pred
def split_feats(self, feats):
return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
feats[1],
feats[2],
feats[3],
F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))
def forward_single(self, x, idx, eval=False, upsampled_size=None):
ins_feat = x
cate_feat = x
# ins branch
# concat coord
x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([ins_feat.shape[0], 1, -1, -1])
x = x.expand([ins_feat.shape[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)
ins_feat = torch.cat([ins_feat, coord_feat], 1)
for i, ins_layer in enumerate(self.ins_convs):
ins_feat = ins_layer(ins_feat)
ins_feat = F.interpolate(ins_feat, scale_factor=2, mode='bilinear')
ins_pred = self.solo_ins_list[idx](ins_feat)
# cate branch
for i, cate_layer in enumerate(self.cate_convs):
if i == self.cate_down_pos:
seg_num_grid = self.seg_num_grids[idx]
cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
cate_feat = cate_layer(cate_feat)
cate_pred = self.solo_cate(cate_feat)
if eval:
ins_pred = F.interpolate(ins_pred.sigmoid(), size=upsampled_size, mode='bilinear')
cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
return ins_pred, cate_pred
def loss(self,
ins_preds,
cate_preds,
gt_bbox_list,
gt_label_list,
gt_mask_list,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in
ins_preds]
ins_label_list, cate_label_list, ins_ind_label_list = multi_apply(
self.solo_target_single,
gt_bbox_list,
gt_label_list,
gt_mask_list,
featmap_sizes=featmap_sizes)
# ins
ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
for ins_labels_level_img, ins_ind_labels_level_img in
zip(ins_labels_level, ins_ind_labels_level)], 0)
for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list))]
ins_preds = [torch.cat([ins_preds_level_img[ins_ind_labels_level_img, ...]
for ins_preds_level_img, ins_ind_labels_level_img in
zip(ins_preds_level, ins_ind_labels_level)], 0)
for ins_preds_level, ins_ind_labels_level in zip(ins_preds, zip(*ins_ind_label_list))]
ins_ind_labels = [
torch.cat([ins_ind_labels_level_img.flatten()
for ins_ind_labels_level_img in ins_ind_labels_level])
for ins_ind_labels_level in zip(*ins_ind_label_list)
]
flatten_ins_ind_labels = torch.cat(ins_ind_labels)
num_ins = flatten_ins_ind_labels.sum()
# dice loss
loss_ins = []
for input, target in zip(ins_preds, ins_labels):
if input.size()[0] == 0:
continue
input = torch.sigmoid(input)
loss_ins.append(dice_loss(input, target))
loss_ins = torch.cat(loss_ins).mean()
loss_ins = loss_ins * self.ins_loss_weight
# cate
cate_labels = [
torch.cat([cate_labels_level_img.flatten()
for cate_labels_level_img in cate_labels_level])
for cate_labels_level in zip(*cate_label_list)
]
flatten_cate_labels = torch.cat(cate_labels)
cate_preds = [
cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
for cate_pred in cate_preds
]
flatten_cate_preds = torch.cat(cate_preds)
loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
return dict(
loss_ins=loss_ins,
loss_cate=loss_cate)
def solo_target_single(self,
gt_bboxes_raw,
gt_labels_raw,
gt_masks_raw,
featmap_sizes=None):
device = gt_labels_raw[0].device
# ins
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
ins_label_list = []
cate_label_list = []
ins_ind_label_list = []
for (lower_bound, upper_bound), stride, featmap_size, num_grid \
in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):
ins_label = torch.zeros([num_grid ** 2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device)
cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)
hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()
if len(hit_indices) == 0:
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = stride / 2
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
if not valid_mask_flag:
continue
upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
# left, top, right, down
top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))
top = max(top_box, coord_h-1)
down = min(down_box, coord_h+1)
left = max(coord_w-1, left_box)
right = min(right_box, coord_w+1)
cate_label[top:(down+1), left:(right+1)] = gt_label
# ins
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.from_numpy(seg_mask).to(device=device)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
ins_ind_label[label] = True
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
return ins_label_list, cate_label_list, ins_ind_label_list
def get_seg(self, seg_preds, cate_preds, img_metas, cfg, rescale=None):
assert len(seg_preds) == len(cate_preds)
num_levels = len(cate_preds)
featmap_size = seg_preds[0].size()[-2:]
result_list = []
for img_id in range(len(img_metas)):
cate_pred_list = [
cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
]
seg_pred_list = [
seg_preds[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
ori_shape = img_metas[img_id]['ori_shape']
cate_pred_list = torch.cat(cate_pred_list, dim=0)
seg_pred_list = torch.cat(seg_pred_list, dim=0)
result = self.get_seg_single(cate_pred_list, seg_pred_list,
featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
result_list.append(result)
return result_list
def get_seg_single(self,
cate_preds,
seg_preds,
featmap_size,
img_shape,
ori_shape,
scale_factor,
cfg,
rescale=False, debug=False):
assert len(cate_preds) == len(seg_preds)
# overall info.
h, w, _ = img_shape
upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)
# process.
inds = (cate_preds > cfg.score_thr)
# category scores.
cate_scores = cate_preds[inds]
if len(cate_scores) == 0:
return None
# category labels.
inds = inds.nonzero()
cate_labels = inds[:, 1]
# strides.
size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)
strides = cate_scores.new_ones(size_trans[-1])
n_stage = len(self.seg_num_grids)
strides[:size_trans[0]] *= self.strides[0]
for ind_ in range(1, n_stage):
strides[size_trans[ind_ - 1]:size_trans[ind_]] *= self.strides[ind_]
strides = strides[inds[:, 0]]
# masks.
seg_preds = seg_preds[inds[:, 0]]
seg_masks = seg_preds > cfg.mask_thr
sum_masks = seg_masks.sum((1, 2)).float()
# filter.
keep = sum_masks > strides
if keep.sum() == 0:
return None
seg_masks = seg_masks[keep, ...]
seg_preds = seg_preds[keep, ...]
sum_masks = sum_masks[keep]
cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep]
# maskness.
seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_scores
# sort and keep top nms_pre
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.nms_pre:
sort_inds = sort_inds[:cfg.nms_pre]
seg_masks = seg_masks[sort_inds, :, :]
seg_preds = seg_preds[sort_inds, :, :]
sum_masks = sum_masks[sort_inds]
cate_scores = cate_scores[sort_inds]
cate_labels = cate_labels[sort_inds]
# Matrix NMS
cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)
# filter.
keep = cate_scores >= cfg.update_thr
if keep.sum() == 0:
return None
seg_preds = seg_preds[keep, :, :]
cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep]
# sort and keep top_k
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.max_per_img:
sort_inds = sort_inds[:cfg.max_per_img]
seg_preds = seg_preds[sort_inds, :, :]
cate_scores = cate_scores[sort_inds]
cate_labels = cate_labels[sort_inds]
seg_preds = F.interpolate(seg_preds.unsqueeze(0),
size=upsampled_size_out,
mode='bilinear')[:, :, :h, :w]
seg_masks = F.interpolate(seg_preds,
size=ori_shape[:2],
mode='bilinear').squeeze(0)
seg_masks = seg_masks > cfg.mask_thr
return seg_masks, cate_labels, cate_scores
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule
INF = 1e8
def center_of_mass(bitmasks):
_, h, w = bitmasks.size()
ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
center_x = m10 / m00
center_y = m01 / m00
return center_x, center_y
def points_nms(heat, kernel=2):
# kernel must be 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=1)
keep = (hmax[:, :, :-1, :-1] == heat).float()
return heat * keep
def dice_loss(input, target):
input = input.contiguous().view(input.size()[0], -1)
target = target.contiguous().view(target.size()[0], -1).float()
a = torch.sum(input * target, 1)
b = torch.sum(input * input, 1) + 0.001
c = torch.sum(target * target, 1) + 0.001
d = (2 * a) / (b + c)
return 1-d
@HEADS.register_module
class SOLOv2Head(nn.Module):
def __init__(self,
num_classes,
in_channels,
seg_feat_channels=256,
stacked_convs=4,
strides=(4, 8, 16, 32, 64),
base_edge_list=(16, 32, 64, 128, 256),
scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
sigma=0.2,
num_grids=None,
ins_out_channels=64,
loss_ins=None,
loss_cate=None,
conv_cfg=None,
norm_cfg=None,
use_dcn_in_tower=False,
type_dcn=None):
super(SOLOv2Head, self).__init__()
self.num_classes = num_classes
self.seg_num_grids = num_grids
self.cate_out_channels = self.num_classes - 1
self.ins_out_channels = ins_out_channels
self.in_channels = in_channels
self.seg_feat_channels = seg_feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.sigma = sigma
self.stacked_convs = stacked_convs
self.kernel_out_channels = self.ins_out_channels * 1 * 1
self.base_edge_list = base_edge_list
self.scale_ranges = scale_ranges
self.loss_cate = build_loss(loss_cate)
self.ins_loss_weight = loss_ins['loss_weight']
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.use_dcn_in_tower = use_dcn_in_tower
self.type_dcn = type_dcn
self._init_layers()
def _init_layers(self):
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
self.cate_convs = nn.ModuleList()
self.kernel_convs = nn.ModuleList()
for i in range(self.stacked_convs):
if self.use_dcn_in_tower:
cfg_conv = dict(type=self.type_dcn)
else:
cfg_conv = self.conv_cfg
chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
self.kernel_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
conv_cfg=cfg_conv,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
chn = self.in_channels if i == 0 else self.seg_feat_channels
self.cate_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
conv_cfg=cfg_conv,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
self.solo_cate = nn.Conv2d(
self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
self.solo_kernel = nn.Conv2d(
self.seg_feat_channels, self.kernel_out_channels, 3, padding=1)
def init_weights(self):
for m in self.cate_convs:
normal_init(m.conv, std=0.01)
for m in self.kernel_convs:
normal_init(m.conv, std=0.01)
bias_cate = bias_init_with_prob(0.01)
normal_init(self.solo_cate, std=0.01, bias=bias_cate)
normal_init(self.solo_kernel, std=0.01)
def forward(self, feats, eval=False):
new_feats = self.split_feats(feats)
featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
cate_pred, kernel_pred = multi_apply(self.forward_single, new_feats,
list(range(len(self.seg_num_grids))),
eval=eval, upsampled_size=upsampled_size)
return cate_pred, kernel_pred
def split_feats(self, feats):
return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
feats[1],
feats[2],
feats[3],
F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))
def forward_single(self, x, idx, eval=False, upsampled_size=None):
ins_kernel_feat = x
# ins branch
# concat coord
x_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-1], device=ins_kernel_feat.device)
y_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-2], device=ins_kernel_feat.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([ins_kernel_feat.shape[0], 1, -1, -1])
x = x.expand([ins_kernel_feat.shape[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)
ins_kernel_feat = torch.cat([ins_kernel_feat, coord_feat], 1)
# kernel branch
kernel_feat = ins_kernel_feat
seg_num_grid = self.seg_num_grids[idx]
kernel_feat = F.interpolate(kernel_feat, size=seg_num_grid, mode='bilinear')
cate_feat = kernel_feat[:, :-2, :, :]
kernel_feat = kernel_feat.contiguous()
for i, kernel_layer in enumerate(self.kernel_convs):
kernel_feat = kernel_layer(kernel_feat)
kernel_pred = self.solo_kernel(kernel_feat)
# cate branch
cate_feat = cate_feat.contiguous()
for i, cate_layer in enumerate(self.cate_convs):
cate_feat = cate_layer(cate_feat)
cate_pred = self.solo_cate(cate_feat)
if eval:
cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
return cate_pred, kernel_pred
def loss(self,
cate_preds,
kernel_preds,
ins_pred,
gt_bbox_list,
gt_label_list,
gt_mask_list,
img_metas,
cfg,
gt_bboxes_ignore=None):
mask_feat_size = ins_pred.size()[-2:]
ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = multi_apply(
self.solov2_target_single,
gt_bbox_list,
gt_label_list,
gt_mask_list,
mask_feat_size=mask_feat_size)
# ins
ins_labels = [torch.cat([ins_labels_level_img
for ins_labels_level_img in ins_labels_level], 0)
for ins_labels_level in zip(*ins_label_list)]
kernel_preds = [[kernel_preds_level_img.view(kernel_preds_level_img.shape[0], -1)[:, grid_orders_level_img]
for kernel_preds_level_img, grid_orders_level_img in
zip(kernel_preds_level, grid_orders_level)]
for kernel_preds_level, grid_orders_level in zip(kernel_preds, zip(*grid_order_list))]
# generate masks
ins_pred = ins_pred
ins_pred_list = []
for b_kernel_pred in kernel_preds:
b_mask_pred = []
for idx, kernel_pred in enumerate(b_kernel_pred):
if kernel_pred.size()[-1] == 0:
continue
cur_ins_pred = ins_pred[idx, ...]
H, W = cur_ins_pred.shape[-2:]
N, I = kernel_pred.shape
cur_ins_pred = cur_ins_pred.unsqueeze(0)
kernel_pred = kernel_pred.permute(1, 0).view(I, -1, 1, 1)
cur_ins_pred = F.conv2d(cur_ins_pred, kernel_pred, stride=1).view(-1, H, W)
b_mask_pred.append(cur_ins_pred)
if len(b_mask_pred) == 0:
b_mask_pred = None
else:
b_mask_pred = torch.cat(b_mask_pred, 0)
ins_pred_list.append(b_mask_pred)
ins_ind_labels = [
torch.cat([ins_ind_labels_level_img.flatten()
for ins_ind_labels_level_img in ins_ind_labels_level])
for ins_ind_labels_level in zip(*ins_ind_label_list)
]
flatten_ins_ind_labels = torch.cat(ins_ind_labels)
num_ins = flatten_ins_ind_labels.sum()
# dice loss
loss_ins = []
for input, target in zip(ins_pred_list, ins_labels):
if input is None:
continue
input = torch.sigmoid(input)
loss_ins.append(dice_loss(input, target))
loss_ins = torch.cat(loss_ins).mean()
loss_ins = loss_ins * self.ins_loss_weight
# cate
cate_labels = [
torch.cat([cate_labels_level_img.flatten()
for cate_labels_level_img in cate_labels_level])
for cate_labels_level in zip(*cate_label_list)
]
flatten_cate_labels = torch.cat(cate_labels)
cate_preds = [
cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
for cate_pred in cate_preds
]
flatten_cate_preds = torch.cat(cate_preds)
loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
return dict(
loss_ins=loss_ins,
loss_cate=loss_cate)
def solov2_target_single(self,
gt_bboxes_raw,
gt_labels_raw,
gt_masks_raw,
mask_feat_size):
device = gt_labels_raw[0].device
# ins
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
ins_label_list = []
cate_label_list = []
ins_ind_label_list = []
grid_order_list = []
for (lower_bound, upper_bound), stride, num_grid \
in zip(self.scale_ranges, self.strides, self.seg_num_grids):
hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()
num_ins = len(hit_indices)
ins_label = []
grid_order = []
cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)
if num_ins == 0:
ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
grid_order_list.append([])
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = 4
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
if not valid_mask_flag:
continue
upsampled_size = (mask_feat_size[0] * 4, mask_feat_size[1] * 4)
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
# left, top, right, down
top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))
top = max(top_box, coord_h-1)
down = min(down_box, coord_h+1)
left = max(coord_w-1, left_box)
right = min(right_box, coord_w+1)
cate_label[top:(down+1), left:(right+1)] = gt_label
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.from_numpy(seg_mask).to(device=device)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
cur_ins_label = torch.zeros([mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8,
device=device)
cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
ins_label.append(cur_ins_label)
ins_ind_label[label] = True
grid_order.append(label)
if len(ins_label) == 0:
ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
else:
ins_label = torch.stack(ins_label, 0)
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
grid_order_list.append(grid_order)
return ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list
def get_seg(self, cate_preds, kernel_preds, seg_pred, img_metas, cfg, rescale=None):
num_levels = len(cate_preds)
featmap_size = seg_pred.size()[-2:]
result_list = []
for img_id in range(len(img_metas)):
cate_pred_list = [
cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
]
seg_pred_list = seg_pred[img_id, ...].unsqueeze(0)
kernel_pred_list = [
kernel_preds[i][img_id].permute(1, 2, 0).view(-1, self.kernel_out_channels).detach()
for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
ori_shape = img_metas[img_id]['ori_shape']
cate_pred_list = torch.cat(cate_pred_list, dim=0)
kernel_pred_list = torch.cat(kernel_pred_list, dim=0)
result = self.get_seg_single(cate_pred_list, seg_pred_list, kernel_pred_list,
featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
result_list.append(result)
return result_list
def get_seg_single(self,
cate_preds,
seg_preds,
kernel_preds,
featmap_size,
img_shape,
ori_shape,
scale_factor,
cfg,
rescale=False, debug=False):
assert len(cate_preds) == len(kernel_preds)
# overall info.
h, w, _ = img_shape
upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)
# process.
inds = (cate_preds > cfg.score_thr)
cate_scores = cate_preds[inds]
if len(cate_scores) == 0:
return None
# cate_labels & kernel_preds
inds = inds.nonzero()
cate_labels = inds[:, 1]
kernel_preds = kernel_preds[inds[:, 0]]
# trans vector.
size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)
strides = kernel_preds.new_ones(size_trans[-1])
n_stage = len(self.seg_num_grids)
strides[:size_trans[0]] *= self.strides[0]
for ind_ in range(1, n_stage):
strides[size_trans[ind_-1]:size_trans[ind_]] *= self.strides[ind_]
strides = strides[inds[:, 0]]
# mask encoding.
I, N = kernel_preds.shape
kernel_preds = kernel_preds.view(I, N, 1, 1)
seg_preds = F.conv2d(seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid()
# mask.
seg_masks = seg_preds > cfg.mask_thr
sum_masks = seg_masks.sum((1, 2)).float()
# filter.
keep = sum_masks > strides
if keep.sum() == 0:
return None
seg_masks = seg_masks[keep, ...]
seg_preds = seg_preds[keep, ...]
sum_masks = sum_masks[keep]
cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep]
# maskness.
seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_scores
# sort and keep top nms_pre
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.nms_pre:
sort_inds = sort_inds[:cfg.nms_pre]
seg_masks = seg_masks[sort_inds, :, :]
seg_preds = seg_preds[sort_inds, :, :]
sum_masks = sum_masks[sort_inds]
cate_scores = cate_scores[sort_inds]
cate_labels = cate_labels[sort_inds]
# Matrix NMS
cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
kernel=cfg.kernel,sigma=cfg.sigma, sum_masks=sum_masks)
# filter.
keep = cate_scores >= cfg.update_thr
if keep.sum() == 0:
return None
seg_preds = seg_preds[keep, :, :]
cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep]
# sort and keep top_k
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.max_per_img:
sort_inds = sort_inds[:cfg.max_per_img]
seg_preds = seg_preds[sort_inds, :, :]
cate_scores = cate_scores[sort_inds]
cate_labels = cate_labels[sort_inds]
seg_preds = F.interpolate(seg_preds.unsqueeze(0),
size=upsampled_size_out,
mode='bilinear')[:, :, :h, :w]
seg_masks = F.interpolate(seg_preds,
size=ori_shape[:2],
mode='bilinear').squeeze(0)
seg_masks = seg_masks > cfg.mask_thr
return seg_masks, cate_labels, cate_scores
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule
INF = 1e8
def center_of_mass(bitmasks):
_, h, w = bitmasks.size()
ys = torch.arange(0, h, dtype=torch.float32, device=bitmasks.device)
xs = torch.arange(0, w, dtype=torch.float32, device=bitmasks.device)
m00 = bitmasks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
m10 = (bitmasks * xs).sum(dim=-1).sum(dim=-1)
m01 = (bitmasks * ys[:, None]).sum(dim=-1).sum(dim=-1)
center_x = m10 / m00
center_y = m01 / m00
return center_x, center_y
def points_nms(heat, kernel=2):
# kernel must be 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=1)
keep = (hmax[:, :, :-1, :-1] == heat).float()
return heat * keep
def dice_loss(input, target):
input = input.contiguous().view(input.size()[0], -1)
target = target.contiguous().view(target.size()[0], -1).float()
a = torch.sum(input * target, 1)
b = torch.sum(input * input, 1) + 0.001
c = torch.sum(target * target, 1) + 0.001
d = (2 * a) / (b + c)
return 1-d
@HEADS.register_module
class SOLOv2LightHead(nn.Module):
def __init__(self,
num_classes,
in_channels,
seg_feat_channels=256,
strides=(4, 8, 16, 32, 64),
base_edge_list=(16, 32, 64, 128, 256),
scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
sigma=0.2,
num_grids=None,
ins_out_channels=64,
stacked_convs=4,
loss_ins=None,
loss_cate=None,
conv_cfg=None,
norm_cfg=None,
use_dcn_in_tower=False,
type_dcn=None):
super(SOLOv2LightHead, self).__init__()
self.num_classes = num_classes
self.seg_num_grids = num_grids
self.cate_out_channels = self.num_classes - 1
self.ins_out_channels = ins_out_channels
self.in_channels = in_channels
self.seg_feat_channels = seg_feat_channels
self.stacked_convs = stacked_convs
self.strides = strides
self.sigma = sigma
self.stacked_convs = stacked_convs
self.kernel_out_channels = self.ins_out_channels * 1 * 1
self.base_edge_list = base_edge_list
self.scale_ranges = scale_ranges
self.loss_cate = build_loss(loss_cate)
self.ins_loss_weight = loss_ins['loss_weight']
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.use_dcn_in_tower = use_dcn_in_tower
self.type_dcn = type_dcn
self._init_layers()
def _init_layers(self):
norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
self.cate_convs = nn.ModuleList()
self.kernel_convs = nn.ModuleList()
for i in range(self.stacked_convs):
if self.use_dcn_in_tower and i == self.stacked_convs - 1:
cfg_conv = dict(type=self.type_dcn)
else:
cfg_conv = self.conv_cfg
chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
self.kernel_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
conv_cfg=cfg_conv,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
chn = self.in_channels if i == 0 else self.seg_feat_channels
self.cate_convs.append(
ConvModule(
chn,
self.seg_feat_channels,
3,
stride=1,
padding=1,
conv_cfg=cfg_conv,
norm_cfg=norm_cfg,
bias=norm_cfg is None))
self.solo_cate = nn.Conv2d(
self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
self.solo_kernel = nn.Conv2d(
self.seg_feat_channels, self.kernel_out_channels, 3, padding=1)
def init_weights(self):
for m in self.cate_convs:
normal_init(m.conv, std=0.01)
for m in self.kernel_convs:
normal_init(m.conv, std=0.01)
bias_cate = bias_init_with_prob(0.01)
normal_init(self.solo_cate, std=0.01, bias=bias_cate)
normal_init(self.solo_kernel, std=0.01)
def forward(self, feats, eval=False):
new_feats = self.split_feats(feats)
featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
cate_pred, kernel_pred = multi_apply(self.forward_single, new_feats,
list(range(len(self.seg_num_grids))),
eval=eval, upsampled_size=upsampled_size)
return cate_pred, kernel_pred
def split_feats(self, feats):
return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
feats[1],
feats[2],
feats[3],
F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))
def forward_single(self, x, idx, eval=False, upsampled_size=None):
ins_kernel_feat = x
# ins branch
# concat coord
x_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-1], device=ins_kernel_feat.device)
y_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-2], device=ins_kernel_feat.device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([ins_kernel_feat.shape[0], 1, -1, -1])
x = x.expand([ins_kernel_feat.shape[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)
ins_kernel_feat = torch.cat([ins_kernel_feat, coord_feat], 1)
# kernel branch
kernel_feat = ins_kernel_feat
seg_num_grid = self.seg_num_grids[idx]
kernel_feat = F.interpolate(kernel_feat, size=seg_num_grid, mode='bilinear')
cate_feat = kernel_feat[:, :-2, :, :]
kernel_feat = kernel_feat.contiguous()
for i, kernel_layer in enumerate(self.kernel_convs):
kernel_feat = kernel_layer(kernel_feat)
kernel_pred = self.solo_kernel(kernel_feat)
# cate branch
cate_feat = cate_feat.contiguous()
for i, cate_layer in enumerate(self.cate_convs):
cate_feat = cate_layer(cate_feat)
cate_pred = self.solo_cate(cate_feat)
if eval:
cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
return cate_pred, kernel_pred
def loss(self,
cate_preds,
kernel_preds,
ins_pred,
gt_bbox_list,
gt_label_list,
gt_mask_list,
img_metas,
cfg,
gt_bboxes_ignore=None):
mask_feat_size = ins_pred.size()[-2:]
ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = multi_apply(
self.solov2_target_single,
gt_bbox_list,
gt_label_list,
gt_mask_list,
mask_feat_size=mask_feat_size)
# ins
ins_labels = [torch.cat([ins_labels_level_img
for ins_labels_level_img in ins_labels_level], 0)
for ins_labels_level in zip(*ins_label_list)]
kernel_preds = [[kernel_preds_level_img.view(kernel_preds_level_img.shape[0], -1)[:, grid_orders_level_img]
for kernel_preds_level_img, grid_orders_level_img in
zip(kernel_preds_level, grid_orders_level)]
for kernel_preds_level, grid_orders_level in zip(kernel_preds, zip(*grid_order_list))]
# generate masks
ins_pred = ins_pred
ins_pred_list = []
for b_kernel_pred in kernel_preds:
b_mask_pred = []
for idx, kernel_pred in enumerate(b_kernel_pred):
if kernel_pred.size()[-1] == 0:
continue
cur_ins_pred = ins_pred[idx, ...]
H, W = cur_ins_pred.shape[-2:]
N, I = kernel_pred.shape
cur_ins_pred = cur_ins_pred.unsqueeze(0)
kernel_pred = kernel_pred.permute(1, 0).view(I, -1, 1, 1)
cur_ins_pred = F.conv2d(cur_ins_pred, kernel_pred, stride=1).view(-1, H, W)
b_mask_pred.append(cur_ins_pred)
if len(b_mask_pred) == 0:
b_mask_pred = None
else:
b_mask_pred = torch.cat(b_mask_pred, 0)
ins_pred_list.append(b_mask_pred)
ins_ind_labels = [
torch.cat([ins_ind_labels_level_img.flatten()
for ins_ind_labels_level_img in ins_ind_labels_level])
for ins_ind_labels_level in zip(*ins_ind_label_list)
]
flatten_ins_ind_labels = torch.cat(ins_ind_labels)
num_ins = flatten_ins_ind_labels.sum()
# dice loss
loss_ins = []
for input, target in zip(ins_pred_list, ins_labels):
if input is None:
continue
input = torch.sigmoid(input)
loss_ins.append(dice_loss(input, target))
loss_ins = torch.cat(loss_ins).mean()
loss_ins = loss_ins * self.ins_loss_weight
# cate
cate_labels = [
torch.cat([cate_labels_level_img.flatten()
for cate_labels_level_img in cate_labels_level])
for cate_labels_level in zip(*cate_label_list)
]
flatten_cate_labels = torch.cat(cate_labels)
cate_preds = [
cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
for cate_pred in cate_preds
]
flatten_cate_preds = torch.cat(cate_preds)
loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
return dict(
loss_ins=loss_ins,
loss_cate=loss_cate)
def solov2_target_single(self,
gt_bboxes_raw,
gt_labels_raw,
gt_masks_raw,
mask_feat_size):
device = gt_labels_raw[0].device
# ins
gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
ins_label_list = []
cate_label_list = []
ins_ind_label_list = []
grid_order_list = []
for (lower_bound, upper_bound), stride, num_grid \
in zip(self.scale_ranges, self.strides, self.seg_num_grids):
hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()
num_ins = len(hit_indices)
ins_label = []
grid_order = []
cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)
if num_ins == 0:
ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
grid_order_list.append([])
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]
half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma
# mass center
gt_masks_pt = torch.from_numpy(gt_masks).to(device=device)
center_ws, center_hs = center_of_mass(gt_masks_pt)
valid_mask_flags = gt_masks_pt.sum(dim=-1).sum(dim=-1) > 0
output_stride = 4
for seg_mask, gt_label, half_h, half_w, center_h, center_w, valid_mask_flag in zip(gt_masks, gt_labels, half_hs, half_ws, center_hs, center_ws, valid_mask_flags):
if not valid_mask_flag:
continue
upsampled_size = (mask_feat_size[0] * 4, mask_feat_size[1] * 4)
coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))
# left, top, right, down
top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))
top = max(top_box, coord_h-1)
down = min(down_box, coord_h+1)
left = max(coord_w-1, left_box)
right = min(right_box, coord_w+1)
cate_label[top:(down+1), left:(right+1)] = gt_label
seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
seg_mask = torch.from_numpy(seg_mask).to(device=device)
for i in range(top, down+1):
for j in range(left, right+1):
label = int(i * num_grid + j)
cur_ins_label = torch.zeros([mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8,
device=device)
cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
ins_label.append(cur_ins_label)
ins_ind_label[label] = True
grid_order.append(label)
if len(ins_label) == 0:
ins_label = torch.zeros([0, mask_feat_size[0], mask_feat_size[1]], dtype=torch.uint8, device=device)
else:
ins_label = torch.stack(ins_label, 0)
ins_label_list.append(ins_label)
cate_label_list.append(cate_label)
ins_ind_label_list.append(ins_ind_label)
grid_order_list.append(grid_order)
return ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list
def get_seg(self, cate_preds, kernel_preds, seg_pred, img_metas, cfg, rescale=None):
num_levels = len(cate_preds)
featmap_size = seg_pred.size()[-2:]
result_list = []
for img_id in range(len(img_metas)):
cate_pred_list = [
cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
]
seg_pred_list = seg_pred[img_id, ...].unsqueeze(0)
kernel_pred_list = [
kernel_preds[i][img_id].permute(1, 2, 0).view(-1, self.kernel_out_channels).detach()
for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
ori_shape = img_metas[img_id]['ori_shape']
cate_pred_list = torch.cat(cate_pred_list, dim=0)
kernel_pred_list = torch.cat(kernel_pred_list, dim=0)
result = self.get_seg_single(cate_pred_list, seg_pred_list, kernel_pred_list,
featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
result_list.append(result)
return result_list
def get_seg_single(self,
cate_preds,
seg_preds,
kernel_preds,
featmap_size,
img_shape,
ori_shape,
scale_factor,
cfg,
rescale=False, debug=False):
assert len(cate_preds) == len(kernel_preds)
# overall info.
h, w, _ = img_shape
upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)
# process.
inds = (cate_preds > cfg.score_thr)
cate_scores = cate_preds[inds]
if len(cate_scores) == 0:
return None
# cate_labels & kernel_preds
inds = inds.nonzero()
cate_labels = inds[:, 1]
kernel_preds = kernel_preds[inds[:, 0]]
# trans vector.
size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0)
strides = kernel_preds.new_ones(size_trans[-1])
n_stage = len(self.seg_num_grids)
strides[:size_trans[0]] *= self.strides[0]
for ind_ in range(1, n_stage):
strides[size_trans[ind_-1]:size_trans[ind_]] *= self.strides[ind_]
strides = strides[inds[:, 0]]
# mask encoding.
I, N = kernel_preds.shape
kernel_preds = kernel_preds.view(I, N, 1, 1)
seg_preds = F.conv2d(seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid()
# mask.
seg_masks = seg_preds > cfg.mask_thr
sum_masks = seg_masks.sum((1, 2)).float()
# filter.
keep = sum_masks > strides
if keep.sum() == 0:
return None
seg_masks = seg_masks[keep, ...]
seg_preds = seg_preds[keep, ...]
sum_masks = sum_masks[keep]
cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep]
# maskness.
seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks
cate_scores *= seg_scores
# sort and keep top nms_pre
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.nms_pre:
sort_inds = sort_inds[:cfg.nms_pre]
seg_masks = seg_masks[sort_inds, :, :]
seg_preds = seg_preds[sort_inds, :, :]
sum_masks = sum_masks[sort_inds]
cate_scores = cate_scores[sort_inds]
cate_labels = cate_labels[sort_inds]
# Matrix NMS
cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
kernel=cfg.kernel,sigma=cfg.sigma, sum_masks=sum_masks)
# filter.
keep = cate_scores >= cfg.update_thr
if keep.sum() == 0:
return None
seg_preds = seg_preds[keep, :, :]
cate_scores = cate_scores[keep]
cate_labels = cate_labels[keep]
# sort and keep top_k
sort_inds = torch.argsort(cate_scores, descending=True)
if len(sort_inds) > cfg.max_per_img:
sort_inds = sort_inds[:cfg.max_per_img]
seg_preds = seg_preds[sort_inds, :, :]
cate_scores = cate_scores[sort_inds]
cate_labels = cate_labels[sort_inds]
seg_preds = F.interpolate(seg_preds.unsqueeze(0),
size=upsampled_size_out,
mode='bilinear')[:, :, :h, :w]
seg_masks = F.interpolate(seg_preds,
size=ori_shape[:2],
mode='bilinear').squeeze(0)
seg_masks = seg_masks > cfg.mask_thr
return seg_masks, cate_labels, cate_scores
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init
from mmdet.core import AnchorGenerator, anchor_target, multi_apply
from ..losses import smooth_l1_loss
from ..registry import HEADS
from .anchor_head import AnchorHead
# TODO: add loss evaluator for SSD
@HEADS.register_module
class SSDHead(AnchorHead):
def __init__(self,
input_size=300,
num_classes=81,
in_channels=(512, 1024, 512, 256, 256, 256),
anchor_strides=(8, 16, 32, 64, 100, 300),
basesize_ratio_range=(0.1, 0.9),
anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
target_means=(.0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0)):
super(AnchorHead, self).__init__()
self.input_size = input_size
self.num_classes = num_classes
self.in_channels = in_channels
self.cls_out_channels = num_classes
num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
reg_convs = []
cls_convs = []
for i in range(len(in_channels)):
reg_convs.append(
nn.Conv2d(
in_channels[i],
num_anchors[i] * 4,
kernel_size=3,
padding=1))
cls_convs.append(
nn.Conv2d(
in_channels[i],
num_anchors[i] * num_classes,
kernel_size=3,
padding=1))
self.reg_convs = nn.ModuleList(reg_convs)
self.cls_convs = nn.ModuleList(cls_convs)
min_ratio, max_ratio = basesize_ratio_range
min_ratio = int(min_ratio * 100)
max_ratio = int(max_ratio * 100)
step = int(np.floor(max_ratio - min_ratio) / (len(in_channels) - 2))
min_sizes = []
max_sizes = []
for r in range(int(min_ratio), int(max_ratio) + 1, step):
min_sizes.append(int(input_size * r / 100))
max_sizes.append(int(input_size * (r + step) / 100))
if input_size == 300:
if basesize_ratio_range[0] == 0.15: # SSD300 COCO
min_sizes.insert(0, int(input_size * 7 / 100))
max_sizes.insert(0, int(input_size * 15 / 100))
elif basesize_ratio_range[0] == 0.2: # SSD300 VOC
min_sizes.insert(0, int(input_size * 10 / 100))
max_sizes.insert(0, int(input_size * 20 / 100))
elif input_size == 512:
if basesize_ratio_range[0] == 0.1: # SSD512 COCO
min_sizes.insert(0, int(input_size * 4 / 100))
max_sizes.insert(0, int(input_size * 10 / 100))
elif basesize_ratio_range[0] == 0.15: # SSD512 VOC
min_sizes.insert(0, int(input_size * 7 / 100))
max_sizes.insert(0, int(input_size * 15 / 100))
self.anchor_generators = []
self.anchor_strides = anchor_strides
for k in range(len(anchor_strides)):
base_size = min_sizes[k]
stride = anchor_strides[k]
ctr = ((stride - 1) / 2., (stride - 1) / 2.)
scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
ratios = [1.]
for r in anchor_ratios[k]:
ratios += [1 / r, r] # 4 or 6 ratio
anchor_generator = AnchorGenerator(
base_size, scales, ratios, scale_major=False, ctr=ctr)
indices = list(range(len(ratios)))
indices.insert(1, len(indices))
anchor_generator.base_anchors = torch.index_select(
anchor_generator.base_anchors, 0, torch.LongTensor(indices))
self.anchor_generators.append(anchor_generator)
self.target_means = target_means
self.target_stds = target_stds
self.use_sigmoid_cls = False
self.cls_focal_loss = False
self.fp16_enabled = False
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m, distribution='uniform', bias=0)
def forward(self, feats):
cls_scores = []
bbox_preds = []
for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
self.cls_convs):
cls_scores.append(cls_conv(feat))
bbox_preds.append(reg_conv(feat))
return cls_scores, bbox_preds
def loss_single(self, cls_score, bbox_pred, labels, label_weights,
bbox_targets, bbox_weights, num_total_samples, cfg):
loss_cls_all = F.cross_entropy(
cls_score, labels, reduction='none') * label_weights
pos_inds = (labels > 0).nonzero().view(-1)
neg_inds = (labels == 0).nonzero().view(-1)
num_pos_samples = pos_inds.size(0)
num_neg_samples = cfg.neg_pos_ratio * num_pos_samples
if num_neg_samples > neg_inds.size(0):
num_neg_samples = neg_inds.size(0)
topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
loss_cls_pos = loss_cls_all[pos_inds].sum()
loss_cls_neg = topk_loss_cls_neg.sum()
loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
loss_bbox = smooth_l1_loss(
bbox_pred,
bbox_targets,
bbox_weights,
beta=cfg.smoothl1_beta,
avg_factor=num_total_samples)
return loss_cls[None], loss_bbox
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)
device = cls_scores[0].device
anchor_list, valid_flag_list = self.get_anchors(
featmap_sizes, img_metas, device=device)
cls_reg_targets = anchor_target(
anchor_list,
valid_flag_list,
gt_bboxes,
img_metas,
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=1,
sampling=False,
unmap_outputs=False)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
num_images = len(img_metas)
all_cls_scores = torch.cat([
s.permute(0, 2, 3, 1).reshape(
num_images, -1, self.cls_out_channels) for s in cls_scores
], 1)
all_labels = torch.cat(labels_list, -1).view(num_images, -1)
all_label_weights = torch.cat(label_weights_list,
-1).view(num_images, -1)
all_bbox_preds = torch.cat([
b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
for b in bbox_preds
], -2)
all_bbox_targets = torch.cat(bbox_targets_list,
-2).view(num_images, -1, 4)
all_bbox_weights = torch.cat(bbox_weights_list,
-2).view(num_images, -1, 4)
# check NaN and Inf
assert torch.isfinite(all_cls_scores).all().item(), \
'classification scores become infinite or NaN!'
assert torch.isfinite(all_bbox_preds).all().item(), \
'bbox predications become infinite or NaN!'
losses_cls, losses_bbox = multi_apply(
self.loss_single,
all_cls_scores,
all_bbox_preds,
all_labels,
all_label_weights,
all_bbox_targets,
all_bbox_weights,
num_total_samples=num_total_pos,
cfg=cfg)
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
from .hrnet import HRNet
from .resnet import ResNet, make_res_layer
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
__all__ = ['ResNet', 'make_res_layer', 'ResNeXt', 'SSDVGG', 'HRNet']
import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.utils import get_root_logger
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer
from .resnet import BasicBlock, Bottleneck
class HRModule(nn.Module):
""" High-Resolution Module for HRNet. In this module, every branch
has 4 BasicBlocks/Bottlenecks. Fusion/Exchange is in this module.
"""
def __init__(self,
num_branches,
blocks,
num_blocks,
in_channels,
num_channels,
multiscale_output=True,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN')):
super(HRModule, self).__init__()
self._check_branches(num_branches, num_blocks, in_channels,
num_channels)
self.in_channels = in_channels
self.num_branches = num_branches
self.multiscale_output = multiscale_output
self.norm_cfg = norm_cfg
self.conv_cfg = conv_cfg
self.with_cp = with_cp
self.branches = self._make_branches(num_branches, blocks, num_blocks,
num_channels)
self.fuse_layers = self._make_fuse_layers()
self.relu = nn.ReLU(inplace=False)
def _check_branches(self, num_branches, num_blocks, in_channels,
num_channels):
if num_branches != len(num_blocks):
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
num_branches, len(num_blocks))
raise ValueError(error_msg)
if num_branches != len(num_channels):
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
num_branches, len(num_channels))
raise ValueError(error_msg)
if num_branches != len(in_channels):
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
num_branches, len(in_channels))
raise ValueError(error_msg)
def _make_one_branch(self,
branch_index,
block,
num_blocks,
num_channels,
stride=1):
downsample = None
if stride != 1 or \
self.in_channels[branch_index] != \
num_channels[branch_index] * block.expansion:
downsample = nn.Sequential(
build_conv_layer(
self.conv_cfg,
self.in_channels[branch_index],
num_channels[branch_index] * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(self.norm_cfg, num_channels[branch_index] *
block.expansion)[1])
layers = []
layers.append(
block(
self.in_channels[branch_index],
num_channels[branch_index],
stride,
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
self.in_channels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
layers.append(
block(
self.in_channels[branch_index],
num_channels[branch_index],
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
return nn.Sequential(*layers)
def _make_branches(self, num_branches, block, num_blocks, num_channels):
branches = []
for i in range(num_branches):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels))
return nn.ModuleList(branches)
def _make_fuse_layers(self):
if self.num_branches == 1:
return None
num_branches = self.num_branches
in_channels = self.in_channels
fuse_layers = []
num_out_branches = num_branches if self.multiscale_output else 1
for i in range(num_out_branches):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[i],
kernel_size=1,
stride=1,
padding=0,
bias=False),
build_norm_layer(self.norm_cfg, in_channels[i])[1],
nn.Upsample(
scale_factor=2**(j - i), mode='nearest')))
elif j == i:
fuse_layer.append(None)
else:
conv_downsamples = []
for k in range(i - j):
if k == i - j - 1:
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[i],
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
in_channels[i])[1]))
else:
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels[j],
in_channels[j],
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
in_channels[j])[1],
nn.ReLU(inplace=False)))
fuse_layer.append(nn.Sequential(*conv_downsamples))
fuse_layers.append(nn.ModuleList(fuse_layer))
return nn.ModuleList(fuse_layers)
def forward(self, x):
if self.num_branches == 1:
return [self.branches[0](x[0])]
for i in range(self.num_branches):
x[i] = self.branches[i](x[i])
x_fuse = []
for i in range(len(self.fuse_layers)):
y = 0
for j in range(self.num_branches):
if i == j:
y += x[j]
else:
y += self.fuse_layers[i][j](x[j])
x_fuse.append(self.relu(y))
return x_fuse
@BACKBONES.register_module
class HRNet(nn.Module):
"""HRNet backbone.
High-Resolution Representations for Labeling Pixels and Regions
arXiv: https://arxiv.org/abs/1904.04514
Args:
extra (dict): detailed configuration for each stage of HRNet.
in_channels (int): Number of input image channels. Normally 3.
conv_cfg (dict): dictionary to construct and config conv layer.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmdet.models import HRNet
>>> import torch
>>> extra = dict(
>>> stage1=dict(
>>> num_modules=1,
>>> num_branches=1,
>>> block='BOTTLENECK',
>>> num_blocks=(4, ),
>>> num_channels=(64, )),
>>> stage2=dict(
>>> num_modules=1,
>>> num_branches=2,
>>> block='BASIC',
>>> num_blocks=(4, 4),
>>> num_channels=(32, 64)),
>>> stage3=dict(
>>> num_modules=4,
>>> num_branches=3,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4),
>>> num_channels=(32, 64, 128)),
>>> stage4=dict(
>>> num_modules=3,
>>> num_branches=4,
>>> block='BASIC',
>>> num_blocks=(4, 4, 4, 4),
>>> num_channels=(32, 64, 128, 256)))
>>> self = HRNet(extra, in_channels=1)
>>> self.eval()
>>> inputs = torch.rand(1, 1, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 32, 8, 8)
(1, 64, 4, 4)
(1, 128, 2, 2)
(1, 256, 1, 1)
"""
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
def __init__(self,
extra,
in_channels=3,
conv_cfg=None,
norm_cfg=dict(type='BN'),
norm_eval=True,
with_cp=False,
zero_init_residual=False):
super(HRNet, self).__init__()
self.extra = extra
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.zero_init_residual = zero_init_residual
# stem net
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
64,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
self.conv_cfg,
64,
64,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True)
# stage 1
self.stage1_cfg = self.extra['stage1']
num_channels = self.stage1_cfg['num_channels'][0]
block_type = self.stage1_cfg['block']
num_blocks = self.stage1_cfg['num_blocks'][0]
block = self.blocks_dict[block_type]
stage1_out_channels = num_channels * block.expansion
self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
# stage 2
self.stage2_cfg = self.extra['stage2']
num_channels = self.stage2_cfg['num_channels']
block_type = self.stage2_cfg['block']
block = self.blocks_dict[block_type]
num_channels = [channel * block.expansion for channel in num_channels]
self.transition1 = self._make_transition_layer([stage1_out_channels],
num_channels)
self.stage2, pre_stage_channels = self._make_stage(
self.stage2_cfg, num_channels)
# stage 3
self.stage3_cfg = self.extra['stage3']
num_channels = self.stage3_cfg['num_channels']
block_type = self.stage3_cfg['block']
block = self.blocks_dict[block_type]
num_channels = [channel * block.expansion for channel in num_channels]
self.transition2 = self._make_transition_layer(pre_stage_channels,
num_channels)
self.stage3, pre_stage_channels = self._make_stage(
self.stage3_cfg, num_channels)
# stage 4
self.stage4_cfg = self.extra['stage4']
num_channels = self.stage4_cfg['num_channels']
block_type = self.stage4_cfg['block']
block = self.blocks_dict[block_type]
num_channels = [channel * block.expansion for channel in num_channels]
self.transition3 = self._make_transition_layer(pre_stage_channels,
num_channels)
self.stage4, pre_stage_channels = self._make_stage(
self.stage4_cfg, num_channels)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def _make_transition_layer(self, num_channels_pre_layer,
num_channels_cur_layer):
num_branches_cur = len(num_channels_cur_layer)
num_branches_pre = len(num_channels_pre_layer)
transition_layers = []
for i in range(num_branches_cur):
if i < num_branches_pre:
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
transition_layers.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
num_channels_pre_layer[i],
num_channels_cur_layer[i],
kernel_size=3,
stride=1,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg,
num_channels_cur_layer[i])[1],
nn.ReLU(inplace=True)))
else:
transition_layers.append(None)
else:
conv_downsamples = []
for j in range(i + 1 - num_branches_pre):
in_channels = num_channels_pre_layer[-1]
out_channels = num_channels_cur_layer[i] \
if j == i - num_branches_pre else in_channels
conv_downsamples.append(
nn.Sequential(
build_conv_layer(
self.conv_cfg,
in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=False),
build_norm_layer(self.norm_cfg, out_channels)[1],
nn.ReLU(inplace=True)))
transition_layers.append(nn.Sequential(*conv_downsamples))
return nn.ModuleList(transition_layers)
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
build_conv_layer(
self.conv_cfg,
inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
layers = []
layers.append(
block(
inplanes,
planes,
stride,
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
inplanes,
planes,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
return nn.Sequential(*layers)
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
num_modules = layer_config['num_modules']
num_branches = layer_config['num_branches']
num_blocks = layer_config['num_blocks']
num_channels = layer_config['num_channels']
block = self.blocks_dict[layer_config['block']]
hr_modules = []
for i in range(num_modules):
# multi_scale_output is only used for the last module
if not multiscale_output and i == num_modules - 1:
reset_multiscale_output = False
else:
reset_multiscale_output = True
hr_modules.append(
HRModule(
num_branches,
block,
num_blocks,
in_channels,
num_channels,
reset_multiscale_output,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
return nn.Sequential(*hr_modules), in_channels
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.relu(x)
x = self.layer1(x)
x_list = []
for i in range(self.stage2_cfg['num_branches']):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
y_list = self.stage2(x_list)
x_list = []
for i in range(self.stage3_cfg['num_branches']):
if self.transition2[i] is not None:
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)
x_list = []
for i in range(self.stage4_cfg['num_branches']):
if self.transition3[i] is not None:
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage4(x_list)
return y_list
def train(self, mode=True):
super(HRNet, self).train(mode)
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.models.plugins import GeneralizedAttention
from mmdet.ops import ContextBlock
from mmdet.utils import get_root_logger
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer
class BasicBlock(nn.Module):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
gcb=None,
gen_attention=None):
super(BasicBlock, self).__init__()
assert dcn is None, "Not implemented yet."
assert gen_attention is None, "Not implemented yet."
assert gcb is None, "Not implemented yet."
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
self.conv1 = build_conv_layer(
conv_cfg,
inplanes,
planes,
3,
stride=stride,
padding=dilation,
dilation=dilation,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
conv_cfg, planes, planes, 3, padding=1, bias=False)
self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
assert not with_cp
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
gcb=None,
gen_attention=None):
"""Bottleneck block for ResNet.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__()
assert style in ['pytorch', 'caffe']
assert dcn is None or isinstance(dcn, dict)
assert gcb is None or isinstance(gcb, dict)
assert gen_attention is None or isinstance(gen_attention, dict)
self.inplanes = inplanes
self.planes = planes
self.stride = stride
self.dilation = dilation
self.style = style
self.with_cp = with_cp
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.dcn = dcn
self.with_dcn = dcn is not None
self.gcb = gcb
self.with_gcb = gcb is not None
self.gen_attention = gen_attention
self.with_gen_attention = gen_attention is not None
if self.style == 'pytorch':
self.conv1_stride = 1
self.conv2_stride = stride
else:
self.conv1_stride = stride
self.conv2_stride = 1
self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
norm_cfg, planes * self.expansion, postfix=3)
self.conv1 = build_conv_layer(
conv_cfg,
inplanes,
planes,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
fallback_on_stride = False
if self.with_dcn:
fallback_on_stride = dcn.pop('fallback_on_stride', False)
if not self.with_dcn or fallback_on_stride:
self.conv2 = build_conv_layer(
conv_cfg,
planes,
planes,
kernel_size=3,
stride=self.conv2_stride,
padding=dilation,
dilation=dilation,
bias=False)
else:
assert self.conv_cfg is None, 'conv_cfg cannot be None for DCN'
self.conv2 = build_conv_layer(
dcn,
planes,
planes,
kernel_size=3,
stride=self.conv2_stride,
padding=dilation,
dilation=dilation,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
conv_cfg,
planes,
planes * self.expansion,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
if self.with_gcb:
gcb_inplanes = planes * self.expansion
self.context_block = ContextBlock(inplanes=gcb_inplanes, **gcb)
# gen_attention
if self.with_gen_attention:
self.gen_attention_block = GeneralizedAttention(
planes, **gen_attention)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
@property
def norm3(self):
return getattr(self, self.norm3_name)
def forward(self, x):
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
out = self.relu(out)
if self.with_gen_attention:
out = self.gen_attention_block(out)
out = self.conv3(out)
out = self.norm3(out)
if self.with_gcb:
out = self.context_block(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
def make_res_layer(block,
inplanes,
planes,
blocks,
stride=1,
dilation=1,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
gcb=None,
gen_attention=None,
gen_attention_blocks=[]):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
build_conv_layer(
conv_cfg,
inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(norm_cfg, planes * block.expansion)[1],
)
layers = []
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=stride,
dilation=dilation,
downsample=downsample,
style=style,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
dcn=dcn,
gcb=gcb,
gen_attention=gen_attention if
(0 in gen_attention_blocks) else None))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=1,
dilation=dilation,
style=style,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
dcn=dcn,
gcb=gcb,
gen_attention=gen_attention if
(i in gen_attention_blocks) else None))
return nn.Sequential(*layers)
@BACKBONES.register_module
class ResNet(nn.Module):
"""ResNet backbone.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmdet.models import ResNet
>>> import torch
>>> self = ResNet(depth=18)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 64, 8, 8)
(1, 128, 4, 4)
(1, 256, 2, 2)
(1, 512, 1, 1)
"""
arch_settings = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self,
depth,
in_channels=3,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(0, 1, 2, 3),
style='pytorch',
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
dcn=None,
stage_with_dcn=(False, False, False, False),
gcb=None,
stage_with_gcb=(False, False, False, False),
gen_attention=None,
stage_with_gen_attention=((), (), (), ()),
with_cp=False,
zero_init_residual=True):
super(ResNet, self).__init__()
if depth not in self.arch_settings:
raise KeyError('invalid depth {} for resnet'.format(depth))
self.depth = depth
self.num_stages = num_stages
assert num_stages >= 1 and num_stages <= 4
self.strides = strides
self.dilations = dilations
assert len(strides) == len(dilations) == num_stages
self.out_indices = out_indices
assert max(out_indices) < num_stages
self.style = style
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.with_cp = with_cp
self.norm_eval = norm_eval
self.dcn = dcn
self.stage_with_dcn = stage_with_dcn
if dcn is not None:
assert len(stage_with_dcn) == num_stages
self.gen_attention = gen_attention
self.gcb = gcb
self.stage_with_gcb = stage_with_gcb
if gcb is not None:
assert len(stage_with_gcb) == num_stages
self.zero_init_residual = zero_init_residual
self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
self.inplanes = 64
self._make_stem_layer(in_channels)
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
stride = strides[i]
dilation = dilations[i]
dcn = self.dcn if self.stage_with_dcn[i] else None
gcb = self.gcb if self.stage_with_gcb[i] else None
planes = 64 * 2**i
res_layer = make_res_layer(
self.block,
self.inplanes,
planes,
num_blocks,
stride=stride,
dilation=dilation,
style=self.style,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
dcn=dcn,
gcb=gcb,
gen_attention=gen_attention,
gen_attention_blocks=stage_with_gen_attention[i])
self.inplanes = planes * self.block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self._freeze_stages()
self.feat_dim = self.block.expansion * 64 * 2**(
len(self.stage_blocks) - 1)
@property
def norm1(self):
return getattr(self, self.norm1_name)
def _make_stem_layer(self, in_channels):
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
64,
kernel_size=7,
stride=2,
padding=3,
bias=False)
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
self.add_module(self.norm1_name, norm1)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.norm1.eval()
for m in [self.conv1, self.norm1]:
for param in m.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
m = getattr(self, 'layer{}'.format(i))
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
if self.dcn is not None:
for m in self.modules():
if isinstance(m, Bottleneck) and hasattr(
m, 'conv2_offset'):
constant_init(m.conv2_offset, 0)
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def train(self, mode=True):
super(ResNet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
import math
import torch.nn as nn
from ..registry import BACKBONES
from ..utils import build_conv_layer, build_norm_layer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
class Bottleneck(_Bottleneck):
def __init__(self, inplanes, planes, groups=1, base_width=4, **kwargs):
"""Bottleneck block for ResNeXt.
If style is "pytorch", the stride-two layer is the 3x3 conv layer,
if it is "caffe", the stride-two layer is the first 1x1 conv layer.
"""
super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
if groups == 1:
width = self.planes
else:
width = math.floor(self.planes * (base_width / 64)) * groups
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
self.norm_cfg, width, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.planes * self.expansion, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.inplanes,
width,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
fallback_on_stride = False
self.with_modulated_dcn = False
if self.with_dcn:
fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
if not self.with_dcn or fallback_on_stride:
self.conv2 = build_conv_layer(
self.conv_cfg,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
else:
assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
self.conv2 = build_conv_layer(
self.dcn,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
self.conv_cfg,
width,
self.planes * self.expansion,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
def make_res_layer(block,
inplanes,
planes,
blocks,
stride=1,
dilation=1,
groups=1,
base_width=4,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
gcb=None):
downsample = None
if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential(
build_conv_layer(
conv_cfg,
inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(norm_cfg, planes * block.expansion)[1],
)
layers = []
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=stride,
dilation=dilation,
downsample=downsample,
groups=groups,
base_width=base_width,
style=style,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
dcn=dcn,
gcb=gcb))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
inplanes=inplanes,
planes=planes,
stride=1,
dilation=dilation,
groups=groups,
base_width=base_width,
style=style,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
dcn=dcn,
gcb=gcb))
return nn.Sequential(*layers)
@BACKBONES.register_module
class ResNeXt(ResNet):
"""ResNeXt backbone.
Args:
depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
num_stages (int): Resnet stages, normally 4.
groups (int): Group of resnext.
base_width (int): Base width of resnext.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
norm_cfg (dict): dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmdet.models import ResNeXt
>>> import torch
>>> self = ResNeXt(depth=50)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 256, 8, 8)
(1, 512, 4, 4)
(1, 1024, 2, 2)
(1, 2048, 1, 1)
"""
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self, groups=1, base_width=4, **kwargs):
super(ResNeXt, self).__init__(**kwargs)
self.groups = groups
self.base_width = base_width
self.inplanes = 64
self.res_layers = []
for i, num_blocks in enumerate(self.stage_blocks):
stride = self.strides[i]
dilation = self.dilations[i]
dcn = self.dcn if self.stage_with_dcn[i] else None
gcb = self.gcb if self.stage_with_gcb[i] else None
planes = 64 * 2**i
res_layer = make_res_layer(
self.block,
self.inplanes,
planes,
num_blocks,
stride=stride,
dilation=dilation,
groups=self.groups,
base_width=self.base_width,
style=self.style,
with_cp=self.with_cp,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
dcn=dcn,
gcb=gcb)
self.inplanes = planes * self.block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self._freeze_stages()
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import VGG, constant_init, kaiming_init, normal_init, xavier_init
from mmcv.runner import load_checkpoint
from mmdet.utils import get_root_logger
from ..registry import BACKBONES
@BACKBONES.register_module
class SSDVGG(VGG):
"""VGG Backbone network for single-shot-detection
Args:
input_size (int): width and height of input, from {300, 512}.
depth (int): Depth of vgg, from {11, 13, 16, 19}.
out_indices (Sequence[int]): Output from which stages.
Example:
>>> self = SSDVGG(input_size=300, depth=11)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 300, 300)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 1024, 19, 19)
(1, 512, 10, 10)
(1, 256, 5, 5)
(1, 256, 3, 3)
(1, 256, 1, 1)
"""
extra_setting = {
300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128),
}
def __init__(self,
input_size,
depth,
with_last_pool=False,
ceil_mode=True,
out_indices=(3, 4),
out_feature_indices=(22, 34),
l2_norm_scale=20.):
# TODO: in_channels for mmcv.VGG
super(SSDVGG, self).__init__(
depth,
with_last_pool=with_last_pool,
ceil_mode=ceil_mode,
out_indices=out_indices)
assert input_size in (300, 512)
self.input_size = input_size
self.features.add_module(
str(len(self.features)),
nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
self.features.add_module(
str(len(self.features)),
nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6))
self.features.add_module(
str(len(self.features)), nn.ReLU(inplace=True))
self.features.add_module(
str(len(self.features)), nn.Conv2d(1024, 1024, kernel_size=1))
self.features.add_module(
str(len(self.features)), nn.ReLU(inplace=True))
self.out_feature_indices = out_feature_indices
self.inplanes = 1024
self.extra = self._make_extra_layers(self.extra_setting[input_size])
self.l2_norm = L2Norm(
self.features[out_feature_indices[0] - 1].out_channels,
l2_norm_scale)
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.features.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
constant_init(m, 1)
elif isinstance(m, nn.Linear):
normal_init(m, std=0.01)
else:
raise TypeError('pretrained must be a str or None')
for m in self.extra.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m, distribution='uniform')
constant_init(self.l2_norm, self.l2_norm.scale)
def forward(self, x):
outs = []
for i, layer in enumerate(self.features):
x = layer(x)
if i in self.out_feature_indices:
outs.append(x)
for i, layer in enumerate(self.extra):
x = F.relu(layer(x), inplace=True)
if i % 2 == 1:
outs.append(x)
outs[0] = self.l2_norm(outs[0])
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def _make_extra_layers(self, outplanes):
layers = []
kernel_sizes = (1, 3)
num_layers = 0
outplane = None
for i in range(len(outplanes)):
if self.inplanes == 'S':
self.inplanes = outplane
continue
k = kernel_sizes[num_layers % 2]
if outplanes[i] == 'S':
outplane = outplanes[i + 1]
conv = nn.Conv2d(
self.inplanes, outplane, k, stride=2, padding=1)
else:
outplane = outplanes[i]
conv = nn.Conv2d(
self.inplanes, outplane, k, stride=1, padding=0)
layers.append(conv)
self.inplanes = outplanes[i]
num_layers += 1
if self.input_size == 512:
layers.append(nn.Conv2d(self.inplanes, 256, 4, padding=1))
return nn.Sequential(*layers)
class L2Norm(nn.Module):
def __init__(self, n_dims, scale=20., eps=1e-10):
super(L2Norm, self).__init__()
self.n_dims = n_dims
self.weight = nn.Parameter(torch.Tensor(self.n_dims))
self.eps = eps
self.scale = scale
def forward(self, x):
# normalization layer convert to FP32 in FP16 training
x_float = x.float()
norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
return (self.weight[None, :, None, None].float().expand_as(x_float) *
x_float / norm).type_as(x)
from .bbox_head import BBoxHead
from .convfc_bbox_head import ConvFCBBoxHead, SharedFCBBoxHead
from .double_bbox_head import DoubleConvFCBBoxHead
__all__ = [
'BBoxHead', 'ConvFCBBoxHead', 'SharedFCBBoxHead', 'DoubleConvFCBBoxHead'
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment