Commit c88ee7de authored by yhcao6

add SSD300

parent 826a5613
# model settings
model = dict(
    type='SingleStageDetector',
    pretrained='data/vgg_backbone.pth',
    backbone=dict(
        type='SSDVGG',
        input_size=300,
        depth=16,
        with_last_pool=False,
        ceil_mode=True,
        out_indices=(3, 4),
        out_feature_indices=(22, 34),
        l2_norm_scale=20),
    neck=None,
    bbox_head=dict(
        type='SSDHead',
        input_size=300,
        in_channels=(512, 1024, 512, 256, 256, 256),
        num_classes=81,
        anchor_strides=(8, 16, 32, 64, 100, 300),
        basesize_ratio_range=(0.15, 0.9),
        anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
        target_means=(.0, .0, .0, .0),
        target_stds=(0.1, 0.1, 0.2, 0.2)))
# model training and testing settings
train_cfg = dict(
    assigner=dict(
        type='MaxIoUAssigner',
        pos_iou_thr=0.5,
        neg_iou_thr=0.5,
        min_pos_iou=0.,
        ignore_iof_thr=-1,
        gt_max_assign_all=False),
    smoothl1_beta=1.,
    allowed_border=-1,
    pos_weight=-1,
    neg_pos_ratio=3,
    debug=False)
test_cfg = dict(
    nms=dict(type='nms', iou_thr=0.45),
    min_bbox_size=0,
    score_thr=0.02,
    max_per_img=200)
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[1, 1, 1], to_rgb=True)
data = dict(
    imgs_per_gpu=4,
    workers_per_gpu=2,
    train=dict(
        type='RepeatDataset',
        dataset=dict(
            type=dataset_type,
            ann_file=data_root + 'annotations/instances_train2017.json',
            img_prefix=data_root + 'train2017/',
            img_scale=(300, 300),
            img_norm_cfg=img_norm_cfg,
            size_divisor=None,
            flip_ratio=0.5,
            with_mask=False,
            with_crowd=False,
            with_label=True,
            test_mode=False,
            extra_aug=dict(
                photo_metric_distortion=dict(
                    brightness_delta=32,
                    contrast_range=(0.5, 1.5),
                    saturation_range=(0.5, 1.5),
                    hue_delta=18),
                expand=dict(
                    mean=img_norm_cfg['mean'],
                    to_rgb=img_norm_cfg['to_rgb'],
                    ratio_range=(1, 4)),
                random_crop=dict(
                    min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3)),
            resize_keep_ratio=False),
        times=10),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(300, 300),
        img_norm_cfg=img_norm_cfg,
        size_divisor=None,
        flip_ratio=0,
        with_mask=False,
        with_label=False,
        test_mode=True,
        resize_keep_ratio=False),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        img_scale=(300, 300),
        img_norm_cfg=img_norm_cfg,
        size_divisor=None,
        flip_ratio=0,
        with_mask=False,
        with_label=False,
        test_mode=True,
        resize_keep_ratio=False))
# optimizer
optimizer = dict(type='SGD', lr=1e-3, momentum=0.9, weight_decay=5e-4)
optimizer_config = dict()
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[8, 11])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/ssd300_coco'
load_from = None
resume_from = None
workflow = [('train', 1)]
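For reference, configs in this style are plain Python files consumed through mmcv. A minimal sketch of loading the file and inspecting a few of the fields above (the path is an assumption; adjust to wherever this config lives in the repo):

```python
from mmcv import Config

cfg = Config.fromfile('configs/ssd300_coco.py')  # hypothetical path

print(cfg.model.bbox_head.num_classes)   # 81 (80 COCO classes + background)
print(cfg.data.train.dataset.extra_aug)  # the SSD-style extra augmentation dict
print(cfg.optimizer)                     # SGD with lr=1e-3, weight_decay=5e-4
```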
@@ -13,7 +13,8 @@ def anchor_target(anchor_list,
                   cfg,
                   gt_labels_list=None,
                   cls_out_channels=1,
-                  sampling=True):
+                  sampling=True,
+                  unmap_outputs=True):
     """Compute regression and classification targets for anchors.

     Args:
@@ -54,7 +55,8 @@ def anchor_target(anchor_list,
             target_stds=target_stds,
             cfg=cfg,
             cls_out_channels=cls_out_channels,
-            sampling=sampling)
+            sampling=sampling,
+            unmap_outputs=unmap_outputs)
     # no valid anchors
     if any([labels is None for labels in all_labels]):
         return None
@@ -94,7 +96,8 @@ def anchor_target_single(flat_anchors,
                          target_stds,
                          cfg,
                          cls_out_channels=1,
-                         sampling=True):
+                         sampling=True,
+                         unmap_outputs=True):
     inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                        img_meta['img_shape'][:2],
                                        cfg.allowed_border)
@@ -140,14 +143,15 @@ def anchor_target_single(flat_anchors,
         label_weights[neg_inds] = 1.0

     # map up to original set of anchors
-    num_total_anchors = flat_anchors.size(0)
-    labels = unmap(labels, num_total_anchors, inside_flags)
-    label_weights = unmap(label_weights, num_total_anchors, inside_flags)
-    if cls_out_channels > 1:
-        labels, label_weights = expand_binary_labels(labels, label_weights,
-                                                     cls_out_channels)
-    bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
-    bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+    if unmap_outputs:
+        num_total_anchors = flat_anchors.size(0)
+        labels = unmap(labels, num_total_anchors, inside_flags)
+        label_weights = unmap(label_weights, num_total_anchors, inside_flags)
+        if cls_out_channels > 1:
+            labels, label_weights = expand_binary_labels(
+                labels, label_weights, cls_out_channels)
+        bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+        bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

     return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
             neg_inds)
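For context on the new flag: `unmap` scatters results computed only for the inside anchors back onto the full anchor set. A rough sketch of its semantics (not the exact mmdet implementation):

```python
import torch

def unmap_sketch(data, count, inside_flags, fill=0):
    # Expand `data`, computed only for anchors where inside_flags is True,
    # into a tensor covering all `count` anchors, padding the rest with `fill`.
    if data.dim() == 1:
        ret = data.new_full((count, ), fill)
    else:
        ret = data.new_full((count, ) + data.size()[1:], fill)
    ret[inside_flags] = data
    return ret
```

SSD can skip this step (it passes `unmap_outputs=False`, see the head below), presumably because its config sets `allowed_border=-1`, so every anchor already counts as inside and the mapping would be the identity.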
@@ -4,9 +4,10 @@ from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
 from .utils import to_tensor, random_scale, show_ann, get_dataset
 from .concat_dataset import ConcatDataset
 from .repeat_dataset import RepeatDataset
+from .extra_aug import ExtraAugmentation

 __all__ = [
     'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
     'build_dataloader', 'to_tensor', 'random_scale', 'show_ann',
-    'get_dataset', 'ConcatDataset', 'RepeatDataset',
+    'get_dataset', 'ConcatDataset', 'RepeatDataset', 'ExtraAugmentation'
 ]
@@ -8,6 +8,7 @@ from torch.utils.data import Dataset
 from .transforms import (ImageTransform, BboxTransform, MaskTransform,
                          Numpy2Tensor)
 from .utils import to_tensor, random_scale
+from .extra_aug import ExtraAugmentation


 class CustomDataset(Dataset):
@@ -44,7 +45,9 @@ class CustomDataset(Dataset):
                  with_mask=True,
                  with_crowd=True,
                  with_label=True,
-                 test_mode=False):
+                 test_mode=False,
+                 extra_aug=None,
+                 resize_keep_ratio=True):
         # load annotations (and proposals)
         self.img_infos = self.load_annotations(ann_file)
         if proposal_file is not None:
@@ -96,6 +99,15 @@ class CustomDataset(Dataset):
         self.mask_transform = MaskTransform()
         self.numpy2tensor = Numpy2Tensor()

+        # extra augmentation (photometric distortion, expand, random crop)
+        if extra_aug is not None:
+            self.extra_aug = ExtraAugmentation(**extra_aug)
+        else:
+            self.extra_aug = None
+
+        # whether to keep the aspect ratio when resizing the image
+        self.resize_keep_ratio = resize_keep_ratio
+
     def __len__(self):
         return len(self.img_infos)
@@ -174,11 +186,17 @@ class CustomDataset(Dataset):
         if len(gt_bboxes) == 0:
             return None

+        # extra augmentation
+        if self.extra_aug is not None:
+            img, gt_bboxes, gt_labels = self.extra_aug(img, gt_bboxes,
+                                                       gt_labels)
+
         # apply transforms
         flip = True if np.random.rand() < self.flip_ratio else False
         img_scale = random_scale(self.img_scales)  # sample a scale
         img, img_shape, pad_shape, scale_factor = self.img_transform(
-            img, img_scale, flip)
+            img, img_scale, flip, keep_ratio=self.resize_keep_ratio)
+        img = img.copy()
         if self.proposals is not None:
             proposals = self.bbox_transform(proposals, img_shape, scale_factor,
                                             flip)
@@ -230,7 +248,7 @@ class CustomDataset(Dataset):
         def prepare_single(img, scale, flip, proposal=None):
             _img, img_shape, pad_shape, scale_factor = self.img_transform(
-                img, scale, flip)
+                img, scale, flip, keep_ratio=self.resize_keep_ratio)
             _img = to_tensor(_img)
             _img_meta = dict(
                 ori_shape=(img_info['height'], img_info['width'], 3),
import mmcv
import numpy as np
from numpy import random

from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps


class PhotoMetricDistortion(object):

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, img, boxes, labels):
        # random brightness
        if random.randint(2):
            delta = random.uniform(-self.brightness_delta,
                                   self.brightness_delta)
            img += delta

        # mode == 0 --> do random contrast first
        # mode == 1 --> do random contrast last
        mode = random.randint(2)
        if mode == 1:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # convert color from BGR to HSV
        img = mmcv.bgr2hsv(img)

        # random saturation
        if random.randint(2):
            img[..., 1] *= random.uniform(self.saturation_lower,
                                          self.saturation_upper)

        # random hue
        if random.randint(2):
            img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
            img[..., 0][img[..., 0] > 360] -= 360
            img[..., 0][img[..., 0] < 0] += 360

        # convert color from HSV to BGR
        img = mmcv.hsv2bgr(img)

        # random contrast
        if mode == 0:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # randomly swap channels
        if random.randint(2):
            img = img[..., random.permutation(3)]

        return img, boxes, labels


class Expand(object):

    def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
        if to_rgb:
            self.mean = mean[::-1]
        else:
            self.mean = mean
        self.min_ratio, self.max_ratio = ratio_range

    def __call__(self, img, boxes, labels):
        if random.randint(2):
            return img, boxes, labels

        h, w, c = img.shape
        ratio = random.uniform(self.min_ratio, self.max_ratio)
        # paste the image at a random position on a larger canvas filled
        # with the dataset mean, then shift the boxes accordingly
        expand_img = np.full((int(h * ratio), int(w * ratio), c),
                             self.mean).astype(img.dtype)
        left = int(random.uniform(0, w * ratio - w))
        top = int(random.uniform(0, h * ratio - h))
        expand_img[top:top + h, left:left + w] = img
        img = expand_img
        boxes += np.tile((left, top), 2)
        return img, boxes, labels


class RandomCrop(object):

    def __init__(self,
                 min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
                 min_crop_size=0.3):
        # mode 1: return the original image unchanged
        self.sample_mode = (1, *min_ious, 0)
        self.min_crop_size = min_crop_size

    def __call__(self, img, boxes, labels):
        h, w, c = img.shape
        while True:
            mode = random.choice(self.sample_mode)
            if mode == 1:
                return img, boxes, labels

            min_iou = mode
            for i in range(50):
                new_w = random.uniform(self.min_crop_size * w, w)
                new_h = random.uniform(self.min_crop_size * h, h)

                # h / w must be in [0.5, 2]
                if new_h / new_w < 0.5 or new_h / new_w > 2:
                    continue

                left = random.uniform(w - new_w)
                top = random.uniform(h - new_h)

                patch = np.array((int(left), int(top), int(left + new_w),
                                  int(top + new_h)))
                overlaps = bbox_overlaps(
                    patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1)
                if overlaps.min() < min_iou:
                    continue

                # the center of each kept box must lie inside the crop patch
                center = (boxes[:, :2] + boxes[:, 2:]) / 2
                mask = (center[:, 0] > patch[0]) * (
                    center[:, 1] > patch[1]) * (center[:, 0] < patch[2]) * (
                        center[:, 1] < patch[3])
                if not mask.any():
                    continue
                boxes = boxes[mask]
                labels = labels[mask]

                # crop the image, then clip and shift the boxes
                img = img[patch[1]:patch[3], patch[0]:patch[2]]
                boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
                boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
                boxes -= np.tile(patch[:2], 2)

                return img, boxes, labels


class ExtraAugmentation(object):

    def __init__(self,
                 photo_metric_distortion=None,
                 expand=None,
                 random_crop=None):
        self.transforms = []
        if photo_metric_distortion is not None:
            self.transforms.append(
                PhotoMetricDistortion(**photo_metric_distortion))
        if expand is not None:
            self.transforms.append(Expand(**expand))
        if random_crop is not None:
            self.transforms.append(RandomCrop(**random_crop))

    def __call__(self, img, boxes, labels):
        img = img.astype(np.float32)
        for transform in self.transforms:
            img, boxes, labels = transform(img, boxes, labels)
        return img, boxes, labels
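A quick smoke test for the pipeline above, using dummy data and the same parameters as the config (illustrative only, not part of the commit):

```python
import numpy as np

aug = ExtraAugmentation(
    photo_metric_distortion=dict(
        brightness_delta=32,
        contrast_range=(0.5, 1.5),
        saturation_range=(0.5, 1.5),
        hue_delta=18),
    expand=dict(
        mean=[123.675, 116.28, 103.53], to_rgb=True, ratio_range=(1, 4)),
    random_crop=dict(min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3))

img = np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8)  # BGR image
boxes = np.array([[30., 40., 100., 120.]], dtype=np.float32)    # x1, y1, x2, y2
labels = np.array([1])

img, boxes, labels = aug(img, boxes, labels)
# Expand / RandomCrop change the spatial size; boxes are shifted to match.
print(img.shape, boxes, labels)
```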
@@ -25,8 +25,14 @@ class ImageTransform(object):
         self.to_rgb = to_rgb
         self.size_divisor = size_divisor

-    def __call__(self, img, scale, flip=False):
-        img, scale_factor = mmcv.imrescale(img, scale, return_scale=True)
+    def __call__(self, img, scale, flip=False, keep_ratio=True):
+        if keep_ratio:
+            img, scale_factor = mmcv.imrescale(img, scale, return_scale=True)
+        else:
+            img, w_scale, h_scale = mmcv.imresize(
+                img, scale, return_scale=True)
+            scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
+                                    dtype=np.float32)
         img_shape = img.shape
         img = mmcv.imnormalize(img, self.mean, self.std, self.to_rgb)
         if flip:
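The `keep_ratio` branch is what lets SSD force a square input. Illustrative numbers for a 480x640 image and scale (300, 300), assuming mmcv's documented return values:

```python
import mmcv
import numpy as np

img = np.zeros((480, 640, 3), dtype=np.uint8)

# keep_ratio=True: one scalar factor, aspect ratio preserved.
_, scale_factor = mmcv.imrescale(img, (300, 300), return_scale=True)
print(scale_factor)  # 0.46875 -> output is 225x300, not square

# keep_ratio=False: independent per-axis factors, output exactly 300x300,
# and scale_factor becomes [w, h, w, h] for rescaling boxes.
_, w_scale, h_scale = mmcv.imresize(img, (300, 300), return_scale=True)
print(w_scale, h_scale)  # 0.46875, 0.625
```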
 from .resnet import ResNet
+from .ssd_vgg import SSDVGG

-__all__ = ['ResNet']
+__all__ = ['ResNet', 'SSDVGG']
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (VGG, xavier_init, constant_init, kaiming_init,
                      normal_init)
from mmcv.runner import load_checkpoint


class SSDVGG(VGG):
    extra_setting = {
        300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
    }

    def __init__(self,
                 input_size,
                 depth,
                 with_last_pool=False,
                 ceil_mode=True,
                 out_indices=(3, 4),
                 out_feature_indices=(22, 34),
                 l2_norm_scale=20.):
        super(SSDVGG, self).__init__(
            depth,
            with_last_pool=with_last_pool,
            ceil_mode=ceil_mode,
            out_indices=out_indices)
        assert input_size in (300, 512)
        # TODO: support 512
        if input_size == 512:
            raise NotImplementedError

        # replace the removed pool5 + fc layers with the SSD extras:
        # pool, dilated conv6 (fc6) and conv7 (fc7)
        self.features.add_module(
            str(len(self.features)),
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
        self.features.add_module(
            str(len(self.features)),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6))
        self.features.add_module(
            str(len(self.features)), nn.ReLU(inplace=True))
        self.features.add_module(
            str(len(self.features)), nn.Conv2d(1024, 1024, kernel_size=1))
        self.features.add_module(
            str(len(self.features)), nn.ReLU(inplace=True))
        self.out_feature_indices = out_feature_indices

        self.inplanes = 1024
        self.extra = self._make_extra_layers(self.extra_setting[input_size])
        self.l2_norm = L2Norm(
            self.features[out_feature_indices[0] - 1].out_channels,
            l2_norm_scale)

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.features.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)
        else:
            raise TypeError('pretrained must be a str or None')

        for m in self.extra.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

        constant_init(self.l2_norm, self.l2_norm.scale)

    def forward(self, x):
        outs = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.out_feature_indices:
                outs.append(x)
        for i, layer in enumerate(self.extra):
            x = F.relu(layer(x), inplace=True)
            if i % 2 == 1:
                outs.append(x)
        outs[0] = self.l2_norm(outs[0])
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def _make_extra_layers(self, outplanes):
        layers = []
        kernel_sizes = (1, 3)
        num_layers = 0
        outplane = None
        for i in range(len(outplanes)):
            if self.inplanes == 'S':
                self.inplanes = outplane
                continue
            k = kernel_sizes[num_layers % 2]
            if outplanes[i] == 'S':
                outplane = outplanes[i + 1]
                conv = nn.Conv2d(
                    self.inplanes, outplane, k, stride=2, padding=1)
            else:
                outplane = outplanes[i]
                conv = nn.Conv2d(
                    self.inplanes, outplane, k, stride=1, padding=0)
            layers.append(conv)
            self.inplanes = outplanes[i]
            num_layers += 1
        return nn.Sequential(*layers)


class L2Norm(nn.Module):

    def __init__(self, n_dims, scale=20., eps=1e-10):
        super(L2Norm, self).__init__()
        self.n_dims = n_dims
        self.weight = nn.Parameter(torch.Tensor(self.n_dims))
        self.eps = eps
        self.scale = scale

    def forward(self, x):
        # channel-wise L2 normalization with a learned per-channel scale
        norm = x.pow(2).sum(1, keepdim=True).sqrt() + self.eps
        return self.weight[None, :, None, None].expand_as(x) * x / norm
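A minimal shape check for the backbone (expected values for a 300x300 input; treat this as a sketch, not a test from the commit):

```python
import torch

backbone = SSDVGG(input_size=300, depth=16)
backbone.init_weights()  # no pretrained path: kaiming/xavier/constant init

feats = backbone(torch.randn(1, 3, 300, 300))
print([tuple(f.shape) for f in feats])
# Expected six levels, matching in_channels in the SSDHead config:
# (1, 512, 38, 38), (1, 1024, 19, 19), (1, 512, 10, 10),
# (1, 256, 5, 5), (1, 256, 3, 3), (1, 256, 1, 1)
```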
 from .retina_head import RetinaHead
+from .ssd_head import SSDHead

-__all__ = ['RetinaHead']
+__all__ = ['RetinaHead', 'SSDHead']
from __future__ import division

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init

from mmdet.core import (AnchorGenerator, anchor_target, multi_apply,
                        delta2bbox, weighted_smoothl1, multiclass_nms)


class SSDHead(nn.Module):

    def __init__(self,
                 input_size=300,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 num_classes=81,
                 anchor_strides=(8, 16, 32, 64, 100, 300),
                 basesize_ratio_range=(0.1, 0.9),
                 anchor_ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0)):
        super(SSDHead, self).__init__()
        # construct head
        num_anchors = [len(ratios) * 2 + 2 for ratios in anchor_ratios]
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.cls_out_channels = num_classes
        reg_convs = []
        cls_convs = []
        for i in range(len(in_channels)):
            reg_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * 4,
                    kernel_size=3,
                    padding=1))
            cls_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * num_classes,
                    kernel_size=3,
                    padding=1))
        self.reg_convs = nn.ModuleList(reg_convs)
        self.cls_convs = nn.ModuleList(cls_convs)

        # compute the min/max anchor sizes of each level, following the
        # size interpolation scheme of the original SSD paper
        min_ratio, max_ratio = basesize_ratio_range
        min_ratio = int(min_ratio * 100)
        max_ratio = int(max_ratio * 100)
        step = int(np.floor(max_ratio - min_ratio) / (len(in_channels) - 2))
        min_sizes = []
        max_sizes = []
        for r in range(int(min_ratio), int(max_ratio) + 1, step):
            min_sizes.append(int(input_size * r / 100))
            max_sizes.append(int(input_size * (r + step) / 100))
        min_sizes.insert(0, int(input_size * 7 / 100))
        max_sizes.insert(0, int(input_size * 15 / 100))

        self.anchor_generators = []
        self.anchor_strides = anchor_strides
        for k in range(len(anchor_strides)):
            base_size = min_sizes[k]
            scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
            ratios = [1.]
            for r in anchor_ratios[k]:
                ratios += [1 / r, r]  # 4 or 6 ratios per location
            anchor_generator = AnchorGenerator(
                base_size, scales, ratios, scale_major=False)
            # reorder the base anchors to match the original SSD anchor order
            indices = list(range(len(ratios)))
            indices.insert(1, len(indices))
            anchor_generator.base_anchors = torch.index_select(
                anchor_generator.base_anchors, 0, torch.LongTensor(indices))
            self.anchor_generators.append(anchor_generator)

        self.target_means = target_means
        self.target_stds = target_stds

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform', bias=0)

    def forward(self, feats):
        cls_scores = []
        bbox_preds = []
        for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                            self.cls_convs):
            cls_scores.append(cls_conv(feat))
            bbox_preds.append(reg_conv(feat))
        return cls_scores, bbox_preds

    def get_anchors(self, featmap_sizes, img_metas):
        """Get anchors according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.

        Returns:
            tuple: anchors of each image, valid flags of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # anchors one time
        multi_level_anchors = []
        for i in range(num_levels):
            anchors = self.anchor_generators[i].grid_anchors(
                featmap_sizes[i], self.anchor_strides[i])
            multi_level_anchors.append(anchors)
        anchor_list = [multi_level_anchors for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level anchors
        valid_flag_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_flags = []
            for i in range(num_levels):
                anchor_stride = self.anchor_strides[i]
                feat_h, feat_w = featmap_sizes[i]
                h, w, _ = img_meta['pad_shape']
                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                flags = self.anchor_generators[i].valid_flags(
                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
                multi_level_flags.append(flags)
            valid_flag_list.append(multi_level_flags)

        return anchor_list, valid_flag_list

    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, bbox_weights, num_pos_samples, cfg):
        loss_cls_all = F.cross_entropy(
            cls_score, labels, reduction='none') * label_weights
        pos_label_inds = (labels > 0).nonzero().view(-1)
        neg_label_inds = (labels == 0).nonzero().view(-1)

        # hard negative mining: keep only the highest-loss negatives,
        # up to neg_pos_ratio times the number of positives
        num_sample_pos = pos_label_inds.size(0)
        num_sample_neg = cfg.neg_pos_ratio * num_sample_pos
        if num_sample_neg > neg_label_inds.size(0):
            num_sample_neg = neg_label_inds.size(0)
        topk_loss_cls_neg, topk_loss_cls_neg_inds = \
            loss_cls_all[neg_label_inds].topk(num_sample_neg)
        loss_cls_pos = loss_cls_all[pos_label_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_pos_samples

        loss_reg = weighted_smoothl1(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=cfg.smoothl1_beta,
            avg_factor=num_pos_samples)
        return loss_cls[None], loss_reg

    def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
             cfg):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generators)

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas)
        cls_reg_targets = anchor_target(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            self.target_means,
            self.target_stds,
            cfg,
            gt_labels_list=gt_labels,
            cls_out_channels=self.cls_out_channels,
            sampling=False,
            unmap_outputs=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        # flatten multi-level predictions and targets to per-image tensors
        num_images = len(img_metas)
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).contiguous().view(
                num_images, -1, self.cls_out_channels) for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list, -1).view(
            num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).contiguous().view(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list, -2).view(
            num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list, -2).view(
            num_images, -1, 4)

        losses_cls, losses_reg = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_pos_samples=num_total_pos,
            cfg=cfg)
        return dict(loss_cls=losses_cls, loss_reg=losses_reg)

    def get_det_bboxes(self,
                       cls_scores,
                       bbox_preds,
                       img_metas,
                       cfg,
                       rescale=False):
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)

        mlvl_anchors = [
            self.anchor_generators[i].grid_anchors(cls_scores[i].size()[-2:],
                                                   self.anchor_strides[i])
            for i in range(num_levels)
        ]
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            results = self._get_det_bboxes_single(
                cls_score_list, bbox_pred_list, mlvl_anchors, img_shape,
                scale_factor, cfg, rescale)
            result_list.append(results)
        return result_list

    def _get_det_bboxes_single(self,
                               cls_scores,
                               bbox_preds,
                               mlvl_anchors,
                               img_shape,
                               scale_factor,
                               cfg,
                               rescale=False):
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        mlvl_proposals = []
        mlvl_scores = []
        for cls_score, bbox_pred, anchors in zip(cls_scores, bbox_preds,
                                                 mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            cls_score = cls_score.permute(1, 2, 0).contiguous().view(
                -1, self.cls_out_channels)
            scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(1, 2, 0).contiguous().view(-1, 4)
            proposals = delta2bbox(anchors, bbox_pred, self.target_means,
                                   self.target_stds, img_shape)
            mlvl_proposals.append(proposals)
            mlvl_scores.append(scores)
        mlvl_proposals = torch.cat(mlvl_proposals)
        if rescale:
            mlvl_proposals /= mlvl_proposals.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        det_bboxes, det_labels = multiclass_nms(mlvl_proposals, mlvl_scores,
                                                cfg.score_thr, cfg.nms,
                                                cfg.max_per_img)
        return det_bboxes, det_labels
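To make the anchor bookkeeping concrete, the min/max size computation can be traced by hand for the config above (input_size=300, basesize_ratio_range=(0.15, 0.9), six levels). This sketch mirrors the constructor logic:

```python
import numpy as np

input_size = 300
min_ratio, max_ratio = 15, 90  # basesize_ratio_range scaled by 100
num_levels = 6

step = int(np.floor(max_ratio - min_ratio) / (num_levels - 2))  # 18
min_sizes, max_sizes = [], []
for r in range(min_ratio, max_ratio + 1, step):  # 15, 33, 51, 69, 87
    min_sizes.append(int(input_size * r / 100))
    max_sizes.append(int(input_size * (r + step) / 100))
min_sizes.insert(0, int(input_size * 7 / 100))   # extra first level at 7%
max_sizes.insert(0, int(input_size * 15 / 100))

print(min_sizes)  # [21, 45, 99, 153, 207, 261]
print(max_sizes)  # [45, 99, 153, 207, 261, 315]

# anchors per location: 2 + 2 * len(ratios) per level -> [4, 6, 6, 6, 4, 4]
print([len(r) * 2 + 2 for r in ([2], [2, 3], [2, 3], [2, 3], [2], [2])])
```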