Commit 89022a26 authored by Jiaqi Wang's avatar Jiaqi Wang Committed by Kai Chen
Browse files

Code of CVPR 2019 Paper: Region Proposal by Guided Anchoring (#594)

* add two stage w/o neck and w/ upperneck

* add rpn r50 c4

* update c4 configs

* fix

* config update

* update config

* minor update

* mask rcnn support c4 train and test

* lr fix

* cascade support upper_neck

* add cascade c4 config

* update config

* update

* update res_layer to new interface

* refactoring

* c4 configs update

* refactoring

* update rpn_c4 config

* rename upper_neck as shared_head

* update

* update configs

* update

* update c4 configs

* update according to commits

* update

* add ga rpn

* test bug fix

* test bug fix when loc_filter_thr is large

* update configs

* update configs

* add ga_retinanet

* ga test bug fix

* update configs

* update

* init masked conv

* update

* update masked conv

* update

* support no ga_sampler

* update

* update

* test with masked_conv

* update comment

* fix flake errors

* fix flake 8 errors

* refactor bounded iou loss

* refactor ga_retina_head

* update configs

* refactor masked conv

* fix flake8 error

* refactor guided_anchor_head and ga_rpn_head

* update configs

* use_sigmoid_cls -> cls_sigmoid_loss; use_focal_loss -> cls_focal_loss

* refactoring

* cls_sigmoid_loss -> use_sigmoid_cls

* fix flake8 error

* add some docs

* rename normalize to norm_cfg

* update configs

* add readme

* update ga_faster config

* update readme

* update readme

* rename configs as r50_caffe

* merge master

* refactor guided anchor target

* update readme

* update approx max iou assigner

* refactor guided anchor target

* update docstring

* refactor ga heads

* fix flake8 error

* update readme

* update model url

* update comments

* refactor get anchors

* update docstring

* not use_loc_filter during training

* add R-101 results

* update to support build loss api

* fix flake8 error

* update readme with x-101 performances

* update readme

* add a link in project readme

* refactor code about ga shape inside flags

* update

* update

* add x101 config files

* add ga_rpn r101 config

* update some comments

* add comments

* add comments

* update comments

* fix flake8 error
parent 3cb84acc
from __future__ import division
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init
from mmdet.core import (AnchorGenerator, anchor_target, anchor_inside_flags,
ga_loc_target, ga_shape_target, delta2bbox,
multi_apply, multiclass_nms)
from mmdet.ops import DeformConv, MaskedConv2d
from ..builder import build_loss
from .anchor_head import AnchorHead
from ..registry import HEADS
from ..utils import bias_init_with_prob
class FeatureAdaption(nn.Module):
    """Adapt a feature map to the predicted anchor shapes.

    The module is built on a DCN v1 deformable convolution. Unlike a regular
    deformable conv, the sampling offsets are not regressed from the feature
    map itself but from the 2-channel anchor shape prediction.

    Args:
        in_channels (int): Number of channels in the input feature map.
        out_channels (int): Number of channels in the output feature map.
        kernel_size (int): Deformable conv kernel size.
        deformable_groups (int): Deformable conv group size.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 deformable_groups=4):
        super(FeatureAdaption, self).__init__()
        # two offset values for every kernel position, for every group
        num_offsets = deformable_groups * (kernel_size * kernel_size * 2)
        # 1x1 conv mapping the 2-channel shape prediction to DCN offsets
        self.conv_offset = nn.Conv2d(2, num_offsets, 1, bias=False)
        self.conv_adaption = DeformConv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
            deformable_groups=deformable_groups)
        self.relu = nn.ReLU(inplace=True)

    def init_weights(self):
        """Initialize the offset conv and the deformable conv weights."""
        normal_init(self.conv_offset, std=0.1)
        normal_init(self.conv_adaption, std=0.01)

    def forward(self, x, shape):
        """Apply the shape-guided deformable conv followed by ReLU.

        Args:
            x (Tensor): Input feature map.
            shape (Tensor): Anchor shape prediction; it is detached, so no
                gradient flows back into it through this module.

        Returns:
            Tensor: The adapted feature map.
        """
        offsets = self.conv_offset(shape.detach())
        adapted = self.conv_adaption(x, offsets)
        return self.relu(adapted)
@HEADS.register_module
class GuidedAnchorHead(AnchorHead):
    """Guided-Anchor-based head (GA-RPN, GA-RetinaNet, etc.).

    This GuidedAnchorHead will predict high-quality feature guided
    anchors and locations where anchors will be kept in inference.
    There are mainly 3 categories of bounding-boxes.

    - Sampled (9) pairs for target assignment. (approxes)
    - The square boxes where the predicted anchors are based on.
        (squares)
    - Guided anchors.

    Please refer to https://arxiv.org/abs/1901.03278 for more details.

    Args:
        num_classes (int): Number of classes.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of channels of the feature map.
        octave_base_scale (int): Base octave scale of each level of
            feature map.
        scales_per_octave (int): Number of octave scales in each level of
            feature map
        octave_ratios (Iterable): octave aspect ratios.
        anchor_strides (Iterable): Anchor strides.
        anchor_base_sizes (Iterable): Anchor base sizes. Defaults to a copy
            of ``anchor_strides`` when None.
        anchoring_means (Iterable): Mean values of anchoring targets.
        anchoring_stds (Iterable): Std values of anchoring targets.
        target_means (Iterable): Mean values of regression targets.
        target_stds (Iterable): Std values of regression targets.
        deformable_groups: (int): Group number of DCN in
            FeatureAdaption module.
        loc_filter_thr (float): Threshold to filter out unconcerned regions.
        loss_loc (dict): Config of location loss.
        loss_shape (dict): Config of anchor shape loss.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of bbox regression loss.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 octave_base_scale=8,
                 scales_per_octave=3,
                 octave_ratios=[0.5, 1.0, 2.0],
                 anchor_strides=[4, 8, 16, 32, 64],
                 anchor_base_sizes=None,
                 anchoring_means=(.0, .0, .0, .0),
                 anchoring_stds=(1.0, 1.0, 1.0, 1.0),
                 target_means=(.0, .0, .0, .0),
                 target_stds=(1.0, 1.0, 1.0, 1.0),
                 deformable_groups=4,
                 loc_filter_thr=0.01,
                 loss_loc=dict(type='FocalLoss',
                               use_sigmoid=True,
                               gamma=2.0,
                               alpha=0.25,
                               loss_weight=1.0),
                 loss_shape=dict(type='IoULoss', beta=0.2, loss_weight=1.0),
                 loss_cls=dict(type='CrossEntropyLoss',
                               use_sigmoid=True,
                               loss_weight=1.0),
                 loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                                loss_weight=1.0)):
        # NOTE: this starts the MRO lookup *after* AnchorHead, i.e. it skips
        # AnchorHead.__init__ and only runs the initializers of its bases.
        super(AnchorHead, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.octave_base_scale = octave_base_scale
        self.scales_per_octave = scales_per_octave
        self.octave_scales = octave_base_scale * np.array(
            [2**(i / scales_per_octave) for i in range(scales_per_octave)])
        self.approxs_per_octave = len(self.octave_scales) * len(octave_ratios)
        self.octave_ratios = octave_ratios
        self.anchor_strides = anchor_strides
        self.anchor_base_sizes = list(
            anchor_strides) if anchor_base_sizes is None else anchor_base_sizes
        self.anchoring_means = anchoring_means
        self.anchoring_stds = anchoring_stds
        self.target_means = target_means
        self.target_stds = target_stds
        self.deformable_groups = deformable_groups
        self.loc_filter_thr = loc_filter_thr
        self.approx_generators = []
        self.square_generators = []
        for anchor_base in self.anchor_base_sizes:
            # Generators for approxs
            self.approx_generators.append(
                AnchorGenerator(anchor_base, self.octave_scales,
                                self.octave_ratios))
            # Generators for squares
            self.square_generators.append(
                AnchorGenerator(anchor_base, [self.octave_base_scale], [1.0]))
        # one anchor per location
        self.num_anchors = 1
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        self.cls_focal_loss = loss_cls['type'] in ['FocalLoss']
        self.loc_focal_loss = loss_loc['type'] in ['FocalLoss']
        if self.use_sigmoid_cls:
            self.cls_out_channels = self.num_classes - 1
        else:
            self.cls_out_channels = self.num_classes

        # build losses
        self.loss_loc = build_loss(loss_loc)
        self.loss_shape = build_loss(loss_shape)
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)

        self._init_layers()

    def _init_layers(self):
        """Create the location, shape, adaption, cls and reg layers."""
        self.relu = nn.ReLU(inplace=True)
        self.conv_loc = nn.Conv2d(self.feat_channels, 1, 1)
        self.conv_shape = nn.Conv2d(self.feat_channels, self.num_anchors * 2,
                                    1)
        self.feature_adaption = FeatureAdaption(
            self.feat_channels,
            self.feat_channels,
            kernel_size=3,
            deformable_groups=self.deformable_groups)
        self.conv_cls = MaskedConv2d(self.feat_channels,
                                     self.num_anchors * self.cls_out_channels,
                                     1)
        self.conv_reg = MaskedConv2d(self.feat_channels, self.num_anchors * 4,
                                     1)

    def init_weights(self):
        """Initialize the weights of all layers of the head."""
        normal_init(self.conv_cls, std=0.01)
        normal_init(self.conv_reg, std=0.01)

        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.conv_loc, std=0.01, bias=bias_cls)
        normal_init(self.conv_shape, std=0.01)

        self.feature_adaption.init_weights()

    def forward_single(self, x):
        """Forward the feature map of a single level.

        Args:
            x (Tensor): Feature map of one level.

        Returns:
            tuple: cls_score, bbox_pred, shape_pred, loc_pred
        """
        loc_pred = self.conv_loc(x)
        shape_pred = self.conv_shape(x)
        x = self.feature_adaption(x, shape_pred)
        # masked conv is only used during inference for speed-up
        if not self.training:
            # only the first sample of the batch is used to build the mask
            mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr
        else:
            mask = None
        cls_score = self.conv_cls(x, mask)
        bbox_pred = self.conv_reg(x, mask)
        return cls_score, bbox_pred, shape_pred, loc_pred

    def forward(self, feats):
        """Apply :meth:`forward_single` to every feature level."""
        return multi_apply(self.forward_single, feats)

    def get_sampled_approxs(self, featmap_sizes, img_metas, cfg):
        """Get sampled approxs and inside flags according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.
            cfg: Config object; ``cfg.allowed_border`` is read here.

        Returns:
            tuple: approxes of each image, inside flags of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # approxes for one time
        multi_level_approxs = []
        for lvl in range(num_levels):
            approxs = self.approx_generators[lvl].grid_anchors(
                featmap_sizes[lvl], self.anchor_strides[lvl])
            multi_level_approxs.append(approxs)
        approxs_list = [multi_level_approxs for _ in range(num_imgs)]

        # for each image, we compute inside flags of multi level approxes
        inside_flag_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_flags = []
            multi_level_approxs = approxs_list[img_id]
            for lvl in range(num_levels):
                approxs = multi_level_approxs[lvl]
                anchor_stride = self.anchor_strides[lvl]
                feat_h, feat_w = featmap_sizes[lvl]
                h, w, _ = img_meta['pad_shape']
                valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
                valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
                flags = self.approx_generators[lvl].valid_flags(
                    (feat_h, feat_w), (valid_feat_h, valid_feat_w))
                inside_flags_list = []
                # a dedicated index is used here so that the level index of
                # the enclosing loop is not clobbered
                for approx_id in range(self.approxs_per_octave):
                    split_valid_flags = flags[approx_id::self.
                                              approxs_per_octave]
                    split_approxs = approxs[approx_id::self.
                                            approxs_per_octave, :]
                    inside_flags = anchor_inside_flags(
                        split_approxs, split_valid_flags,
                        img_meta['img_shape'][:2], cfg.allowed_border)
                    inside_flags_list.append(inside_flags)
                # inside_flag for a position is true if any anchor in this
                # position is true
                inside_flags = (
                    torch.stack(inside_flags_list, 0).sum(dim=0) > 0)
                multi_level_flags.append(inside_flags)
            inside_flag_list.append(multi_level_flags)
        return approxs_list, inside_flag_list

    def get_anchors(self,
                    featmap_sizes,
                    shape_preds,
                    loc_preds,
                    img_metas,
                    use_loc_filter=False):
        """Get squares according to feature map sizes and guided
        anchors.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            shape_preds (list[tensor]): Multi-level shape predictions.
            loc_preds (list[tensor]): Multi-level location predictions.
            img_metas (list[dict]): Image meta info.
            use_loc_filter (bool): Use loc filter or not.

        Returns:
            tuple: square approxs of each image, guided anchors of each image,
                loc masks of each image
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # squares for one time
        multi_level_squares = []
        for lvl in range(num_levels):
            squares = self.square_generators[lvl].grid_anchors(
                featmap_sizes[lvl], self.anchor_strides[lvl])
            multi_level_squares.append(squares)
        squares_list = [multi_level_squares for _ in range(num_imgs)]

        # for each image, we compute multi level guided anchors
        guided_anchors_list = []
        loc_mask_list = []
        for img_id, img_meta in enumerate(img_metas):
            multi_level_guided_anchors = []
            multi_level_loc_mask = []
            for lvl in range(num_levels):
                squares = squares_list[img_id][lvl]
                shape_pred = shape_preds[lvl][img_id]
                loc_pred = loc_preds[lvl][img_id]
                guided_anchors, loc_mask = self.get_guided_anchors_single(
                    squares,
                    shape_pred,
                    loc_pred,
                    use_loc_filter=use_loc_filter)
                multi_level_guided_anchors.append(guided_anchors)
                multi_level_loc_mask.append(loc_mask)
            guided_anchors_list.append(multi_level_guided_anchors)
            loc_mask_list.append(multi_level_loc_mask)
        return squares_list, guided_anchors_list, loc_mask_list

    def get_guided_anchors_single(self,
                                  squares,
                                  shape_pred,
                                  loc_pred,
                                  use_loc_filter=False):
        """Get guided anchors and loc masks for a single level.

        Args:
            squares (tensor): Squares of a single level.
            shape_pred (tensor): Shape predictions of a single level.
            loc_pred (tensor): Loc predictions of a single level.
            use_loc_filter (bool): Use loc filter or not.

        Returns:
            tuple: guided anchors, location masks
        """
        # calculate location filtering mask
        loc_pred = loc_pred.sigmoid().detach()
        if use_loc_filter:
            loc_mask = loc_pred >= self.loc_filter_thr
        else:
            # sigmoid outputs are never negative, so every location is kept
            loc_mask = loc_pred >= 0.0
        mask = loc_mask.permute(1, 2, 0).expand(-1, -1, self.num_anchors)
        mask = mask.contiguous().view(-1)
        # calculate guided anchors
        squares = squares[mask]
        anchor_deltas = shape_pred.permute(1, 2, 0).contiguous().view(
            -1, 2).detach()[mask]
        # only the last two delta channels are predicted; the first two stay 0
        bbox_deltas = anchor_deltas.new_full(squares.size(), 0)
        bbox_deltas[:, 2:] = anchor_deltas
        guided_anchors = delta2bbox(squares,
                                    bbox_deltas,
                                    self.anchoring_means,
                                    self.anchoring_stds,
                                    wh_ratio_clip=1e-6)
        return guided_anchors, mask

    def loss_shape_single(self, shape_pred, bbox_anchors, bbox_gts,
                          anchor_weights, anchor_total_num):
        """Compute the anchor shape loss of a single level.

        Args:
            shape_pred (tensor): Shape predictions of a single level.
            bbox_anchors (tensor): Anchors the shape deltas are applied to.
            bbox_gts (tensor): Ground-truth boxes matched to the anchors.
            anchor_weights (tensor): Per-anchor weights; only rows whose first
                weight is positive contribute to the loss.
            anchor_total_num (int): Average factor passed to the loss.

        Returns:
            The value returned by the configured shape loss.
        """
        shape_pred = shape_pred.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        bbox_anchors = bbox_anchors.contiguous().view(-1, 4)
        bbox_gts = bbox_gts.contiguous().view(-1, 4)
        anchor_weights = anchor_weights.contiguous().view(-1, 4)
        bbox_deltas = bbox_anchors.new_full(bbox_anchors.size(), 0)
        bbox_deltas[:, 2:] += shape_pred
        # filter out negative samples to speed-up weighted_bounded_iou_loss
        inds = torch.nonzero(anchor_weights[:, 0] > 0).squeeze(1)
        bbox_deltas_ = bbox_deltas[inds]
        bbox_anchors_ = bbox_anchors[inds]
        bbox_gts_ = bbox_gts[inds]
        anchor_weights_ = anchor_weights[inds]
        pred_anchors_ = delta2bbox(bbox_anchors_,
                                   bbox_deltas_,
                                   self.anchoring_means,
                                   self.anchoring_stds,
                                   wh_ratio_clip=1e-6)
        loss_shape = self.loss_shape(pred_anchors_,
                                     bbox_gts_,
                                     anchor_weights_,
                                     avg_factor=anchor_total_num)
        return loss_shape

    def loss_loc_single(self, loc_pred, loc_target, loc_weight, loc_avg_factor,
                        cfg):
        """Compute the anchor location loss of a single level.

        ``cfg`` is accepted for signature compatibility but is not used.
        """
        loss_loc = self.loss_loc(loc_pred.reshape(-1, 1),
                                 loc_target.reshape(-1, 1).long(),
                                 loc_weight.reshape(-1, 1),
                                 avg_factor=loc_avg_factor)
        return loss_loc

    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute classification, regression, shape and location losses.

        Returns:
            dict | None: Dict with keys ``loss_cls``, ``loss_bbox``,
                ``loss_shape`` and ``loss_loc``; None when either the shape
                targets or the anchor targets could not be computed.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.approx_generators)

        # get loc targets
        loc_targets, loc_weights, loc_avg_factor = ga_loc_target(
            gt_bboxes,
            featmap_sizes,
            self.octave_base_scale,
            self.anchor_strides,
            center_ratio=cfg.center_ratio,
            ignore_ratio=cfg.ignore_ratio)

        # get sampled approxes
        approxs_list, inside_flag_list = self.get_sampled_approxs(
            featmap_sizes, img_metas, cfg)
        # get squares and guided anchors
        squares_list, guided_anchors_list, _ = self.get_anchors(
            featmap_sizes, shape_preds, loc_preds, img_metas)

        # get shape targets; sampling is used only if a ga_sampler is set
        ga_sampling = hasattr(cfg, 'ga_sampler')
        shape_targets = ga_shape_target(approxs_list,
                                        inside_flag_list,
                                        squares_list,
                                        gt_bboxes,
                                        img_metas,
                                        self.approxs_per_octave,
                                        cfg,
                                        sampling=ga_sampling)
        if shape_targets is None:
            return None
        (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
         anchor_bg_num) = shape_targets
        anchor_total_num = (anchor_fg_num + anchor_bg_num
                            if ga_sampling else anchor_fg_num)

        # get anchor targets; focal loss is trained without sampling
        cls_sampling = not self.cls_focal_loss
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = anchor_target(guided_anchors_list,
                                        inside_flag_list,
                                        gt_bboxes,
                                        img_metas,
                                        self.target_means,
                                        self.target_stds,
                                        cfg,
                                        gt_bboxes_ignore_list=gt_bboxes_ignore,
                                        gt_labels_list=gt_labels,
                                        label_channels=label_channels,
                                        sampling=cls_sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (num_total_pos if self.cls_focal_loss else
                             num_total_pos + num_total_neg)

        # get classification and bbox regression losses
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_targets_list,
            bbox_weights_list,
            num_total_samples=num_total_samples,
            cfg=cfg)

        # get anchor location loss
        losses_loc, = multi_apply(self.loss_loc_single,
                                  loc_preds,
                                  loc_targets,
                                  loc_weights,
                                  loc_avg_factor=loc_avg_factor,
                                  cfg=cfg)

        # get anchor shape loss
        losses_shape, = multi_apply(self.loss_shape_single,
                                    shape_preds,
                                    bbox_anchors_list,
                                    bbox_gts_list,
                                    anchor_weights_list,
                                    anchor_total_num=anchor_total_num)

        return dict(loss_cls=losses_cls,
                    loss_bbox=losses_bbox,
                    loss_shape=losses_shape,
                    loss_loc=losses_loc)

    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   shape_preds,
                   loc_preds,
                   img_metas,
                   cfg,
                   rescale=False):
        """Turn multi-level head outputs into detections for every image.

        The location filter is applied only when the module is not in
        training mode.

        Returns:
            list: One entry per image, as returned by
                :meth:`get_bboxes_single`.
        """
        assert len(cls_scores) == len(bbox_preds) == len(shape_preds) == len(
            loc_preds)
        num_levels = len(cls_scores)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        # get guided anchors
        _, guided_anchors, loc_masks = self.get_anchors(
            featmap_sizes,
            shape_preds,
            loc_preds,
            img_metas,
            use_loc_filter=not self.training)
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            guided_anchor_list = [
                guided_anchors[img_id][i].detach() for i in range(num_levels)
            ]
            loc_mask_list = [
                loc_masks[img_id][i].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
                                               guided_anchor_list,
                                               loc_mask_list, img_shape,
                                               scale_factor, cfg, rescale)
            result_list.append(proposals)
        return result_list

    def get_bboxes_single(self,
                          cls_scores,
                          bbox_preds,
                          mlvl_anchors,
                          mlvl_masks,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False):
        """Decode and NMS the predictions of a single image.

        Args:
            cls_scores (list[tensor]): Per-level classification scores.
            bbox_preds (list[tensor]): Per-level bbox deltas.
            mlvl_anchors (list[tensor]): Per-level guided anchors, already
                filtered by the location mask.
            mlvl_masks (list[tensor]): Per-level flattened location masks.
            img_shape (tuple): Image shape passed to ``delta2bbox``.
            scale_factor: Factor the boxes are divided by when ``rescale``.
            cfg: Test config (``nms_pre``, ``score_thr``, ``nms``,
                ``max_per_img`` are read).
            rescale (bool): Whether to divide the boxes by ``scale_factor``.

        Returns:
            tuple: (det_bboxes, det_labels). When no location survives the
                mask on any level, an empty ``(0, 5)`` box tensor and an
                empty ``(0, )`` long label tensor are returned.
        """
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
        # loop-invariant, so it is looked up once
        nms_pre = cfg.get('nms_pre', -1)
        mlvl_bboxes = []
        mlvl_scores = []
        for cls_score, bbox_pred, anchors, mask in zip(cls_scores, bbox_preds,
                                                       mlvl_anchors,
                                                       mlvl_masks):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            # if no location is kept, end.
            if mask.sum() == 0:
                continue
            # reshape scores and bbox_pred
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            # filter scores, bbox_pred w.r.t. mask.
            # anchors are filtered in get_anchors() beforehand.
            # (indexing a 2-D tensor with a 1-D mask always yields a 2-D
            # tensor, so no 0-dim special case is needed afterwards)
            scores = scores[mask, :]
            bbox_pred = bbox_pred[mask, :]
            # filter anchors, bbox_pred, scores w.r.t. scores
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=1)
                else:
                    max_scores, _ = scores[:, 1:].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
            bboxes = delta2bbox(anchors, bbox_pred, self.target_means,
                                self.target_stds, img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
        if not mlvl_bboxes:
            # Every level was filtered out (e.g. loc_filter_thr is large).
            # torch.cat() rejects an empty list, so report "no detections"
            # explicitly instead of crashing.
            ref = cls_scores[0]
            return (ref.new_zeros((0, 5)),
                    ref.new_zeros((0, ), dtype=torch.long))
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        if self.use_sigmoid_cls:
            # prepend a zero-score column for the background class
            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
            mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        # multi class NMS
        det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
                                                cfg.score_thr, cfg.nms,
                                                cfg.max_per_img)
        return det_bboxes, det_labels
from .cross_entropy_loss import CrossEntropyLoss from .cross_entropy_loss import CrossEntropyLoss
from .focal_loss import FocalLoss from .focal_loss import FocalLoss
from .smooth_l1_loss import SmoothL1Loss from .smooth_l1_loss import SmoothL1Loss
from .iou_loss import IoULoss
__all__ = ['CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss'] __all__ = ['CrossEntropyLoss', 'FocalLoss', 'SmoothL1Loss', 'IoULoss']
import torch.nn as nn
from mmdet.core import weighted_iou_loss
from ..registry import LOSSES
@LOSSES.register_module
class IoULoss(nn.Module):
    """IoU-based bbox loss, a thin module wrapper of ``weighted_iou_loss``.

    Args:
        style (str): IoU loss variant forwarded to ``weighted_iou_loss``.
        beta (float): ``beta`` argument forwarded to ``weighted_iou_loss``.
        eps (float): ``eps`` argument forwarded to ``weighted_iou_loss``.
        loss_weight (float): Scalar the resulting loss is multiplied by.
    """

    def __init__(self, style='naive', beta=0.2, eps=1e-3, loss_weight=1.0):
        super(IoULoss, self).__init__()
        self.style = style
        self.beta = beta
        self.eps = eps
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight, *args, **kwargs):
        """Compute the weighted IoU loss scaled by ``loss_weight``.

        Extra positional and keyword arguments (e.g. ``avg_factor``) are
        passed through to ``weighted_iou_loss`` unchanged.
        """
        # The configured style used to be stored but never forwarded, so the
        # loss silently ignored it.
        # NOTE(review): assumes weighted_iou_loss accepts a ``style`` keyword
        # -- confirm against mmdet.core.
        loss = self.loss_weight * weighted_iou_loss(
            pred,
            target,
            weight,
            style=self.style,
            beta=self.beta,
            eps=self.eps,
            *args,
            **kwargs)
        return loss
...@@ -6,11 +6,13 @@ from .nms import nms, soft_nms ...@@ -6,11 +6,13 @@ from .nms import nms, soft_nms
from .roi_align import RoIAlign, roi_align from .roi_align import RoIAlign, roi_align
from .roi_pool import RoIPool, roi_pool from .roi_pool import RoIPool, roi_pool
from .sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss from .sigmoid_focal_loss import SigmoidFocalLoss, sigmoid_focal_loss
from .masked_conv import MaskedConv2d
__all__ = [ __all__ = [
'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack', 'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv', 'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv', 'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
'deform_roi_pooling', 'SigmoidFocalLoss', 'sigmoid_focal_loss' 'deform_roi_pooling', 'SigmoidFocalLoss', 'sigmoid_focal_loss',
'MaskedConv2d'
] ]
from .functions.masked_conv import masked_conv2d
from .modules.masked_conv import MaskedConv2d
__all__ = ['masked_conv2d', 'MaskedConv2d']
import math
import torch
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from .. import masked_conv2d_cuda
class MaskedConv2dFunction(Function):
    """Forward-only 2-D convolution evaluated at masked positions only.

    Output positions where the mask is not positive are left at zero. Only
    batch size 1, stride 1 and CUDA tensors are supported, and no gradient
    is provided.
    """

    @staticmethod
    def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
        """Run the masked convolution.

        Args:
            features (Tensor): Input of shape (1, C_in, H, W), on CUDA.
            mask (Tensor): Mask of shape (1, H, W); positions with a value
                greater than 0 are computed.
            weight (Tensor): Kernel of shape (C_out, C_in, kH, kW).
            bias (Tensor | None): Bias of shape (C_out, ), or None.
            padding (int | tuple): Zero padding per side.
            stride (int | tuple): Must be 1.

        Returns:
            Tensor: Output of shape (1, C_out, out_h, out_w).

        Raises:
            ValueError: If the stride is not 1.
            NotImplementedError: If ``features`` is not a CUDA tensor.
        """
        assert mask.dim() == 3 and mask.size(0) == 1
        assert features.dim() == 4 and features.size(0) == 1
        assert features.size()[2:] == mask.size()[1:]
        pad_h, pad_w = _pair(padding)
        stride_h, stride_w = _pair(stride)
        if stride_h != 1 or stride_w != 1:
            raise ValueError(
                'Stride could only be 1 in masked_conv2d currently.')
        if not features.is_cuda:
            raise NotImplementedError

        out_channel, in_channel, kernel_h, kernel_w = weight.size()

        batch_size = features.size(0)
        out_h = int(
            math.floor((features.size(2) + 2 * pad_h -
                        (kernel_h - 1) - 1) / stride_h + 1))
        # the width must be derived from kernel_w, not kernel_h
        out_w = int(
            math.floor((features.size(3) + 2 * pad_w -
                        (kernel_w - 1) - 1) / stride_w + 1))
        output = features.new_zeros(batch_size, out_channel, out_h, out_w)

        mask_inds = torch.nonzero(mask[0] > 0)
        if mask_inds.numel() == 0:
            # Nothing is selected: skip the CUDA kernels, which cannot be
            # launched for an empty workload, and return the all-zero output.
            return output
        mask_h_idx = mask_inds[:, 0].contiguous()
        mask_w_idx = mask_inds[:, 1].contiguous()
        # one column of unfolded input per masked position
        data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
                                      mask_inds.size(0))
        masked_conv2d_cuda.masked_im2col_forward(features, mask_h_idx,
                                                 mask_w_idx, kernel_h,
                                                 kernel_w, pad_h, pad_w,
                                                 data_col)

        weight_mat = weight.view(out_channel, -1)
        if bias is None:
            masked_output = torch.mm(weight_mat, data_col)
        else:
            # bias + weight_mat @ data_col (beta = alpha = 1)
            masked_output = torch.addmm(bias[:, None], weight_mat, data_col)
        # scatter the per-position results back into the dense output
        masked_conv2d_cuda.masked_col2im_forward(masked_output, mask_h_idx,
                                                 mask_w_idx, out_h, out_w,
                                                 out_channel, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """No gradient is implemented; return None for every forward input.

        ``forward`` takes six inputs besides ``ctx`` (features, mask, weight,
        bias, padding, stride), so six values must be returned.
        """
        return (None, ) * 6


masked_conv2d = MaskedConv2dFunction.apply
import torch.nn as nn
from ..functions.masked_conv import masked_conv2d
class MaskedConv2d(nn.Conv2d):
    """A MaskedConv2d which inherits the official Conv2d.

    The masked forward doesn't implement the backward function and only
    supports the stride parameter to be 1 currently. The masked path uses
    only the layer's weight, bias, padding and stride; ``dilation`` and
    ``groups`` are not passed to it.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super(MaskedConv2d,
              self).__init__(in_channels, out_channels, kernel_size, stride,
                             padding, dilation, groups, bias)

    def forward(self, input, mask=None):
        """Convolve ``input``, optionally only at the masked positions.

        Args:
            input (Tensor): Input feature map.
            mask (Tensor | None): Positions to compute. When None the layer
                behaves exactly like a regular ``nn.Conv2d``.

        Returns:
            Tensor: The convolution output.
        """
        if mask is None:  # fallback to the normal Conv2d
            return super(MaskedConv2d, self).forward(input)
        # Forward the stride as well, so that an unsupported stride is
        # rejected by masked_conv2d instead of being silently treated as 1.
        return masked_conv2d(input, mask, self.weight, self.bias,
                             self.padding, self.stride)
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Build script for the ``masked_conv2d_cuda`` extension module: one C++
# binding source plus one CUDA kernel source, compiled through PyTorch's
# BuildExtension.
setup(
    name='masked_conv2d_cuda',
    ext_modules=[
        CUDAExtension('masked_conv2d_cuda', [
            'src/masked_conv2d_cuda.cpp',
            'src/masked_conv2d_kernel.cu',
        ]),
    ],
    cmdclass={'build_ext': BuildExtension})
#include <torch/extension.h>
#include <cmath>
#include <vector>
int MaskedIm2colForwardLaucher(const at::Tensor im, const int height,
const int width, const int channels,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const at::Tensor mask_h_idx,
const at::Tensor mask_w_idx, const int mask_cnt,
at::Tensor col);
int MaskedCol2imForwardLaucher(const at::Tensor col, const int height,
const int width, const int channels,
const at::Tensor mask_h_idx,
const at::Tensor mask_w_idx, const int mask_cnt,
at::Tensor im);
// Argument validation helpers: each check fails through AT_CHECK with a
// message naming the offending argument.
// CHECK_CUDA: the tensor must be a CUDA tensor.
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
// CHECK_CONTIGUOUS: the tensor must be contiguous in memory.
#define CHECK_CONTIGUOUS(x) \
  AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
// CHECK_INPUT: both of the above.
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
// Validates the arguments and gathers, for every masked position, the input
// patch under the kernel window into `col`.
// `im` is indexed as (n, channels, height, width); the number of masked
// positions is the length of `mask_h_idx`. Always returns 1.
int masked_im2col_forward_cuda(const at::Tensor im, const at::Tensor mask_h_idx,
                               const at::Tensor mask_w_idx, const int kernel_h,
                               const int kernel_w, const int pad_h,
                               const int pad_w, at::Tensor col) {
  CHECK_INPUT(im);
  CHECK_INPUT(mask_h_idx);
  CHECK_INPUT(mask_w_idx);
  CHECK_INPUT(col);

  // geometry of the input image and size of the workload
  const int num_channels = im.size(1);
  const int im_height = im.size(2);
  const int im_width = im.size(3);
  const int num_masked = mask_h_idx.size(0);

  MaskedIm2colForwardLaucher(im, im_height, im_width, num_channels, kernel_h,
                             kernel_w, pad_h, pad_w, mask_h_idx, mask_w_idx,
                             num_masked, col);
  return 1;
}
// Validates the arguments and scatters the per-position results in `col`
// back into the dense image `im` at the coordinates given by
// (`mask_h_idx`, `mask_w_idx`). `height`, `width` and `channels` describe
// `im`; the number of masked positions is the length of `mask_h_idx`.
// Always returns 1.
int masked_col2im_forward_cuda(const at::Tensor col,
                               const at::Tensor mask_h_idx,
                               const at::Tensor mask_w_idx, int height,
                               int width, int channels, at::Tensor im) {
  CHECK_INPUT(col);
  CHECK_INPUT(mask_h_idx);
  CHECK_INPUT(mask_w_idx);
  CHECK_INPUT(im);

  const int num_masked = mask_h_idx.size(0);

  MaskedCol2imForwardLaucher(col, height, width, channels, mask_h_idx,
                             mask_w_idx, num_masked, im);
  return 1;
}
// Python bindings: expose the two forward entry points of the extension
// under the names `masked_im2col_forward` and `masked_col2im_forward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("masked_im2col_forward", &masked_im2col_forward_cuda,
        "masked_im2col forward (CUDA)");
  m.def("masked_col2im_forward", &masked_col2im_forward_cuda,
        "masked_col2im forward (CUDA)");
}
\ No newline at end of file
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
// Iterates `i` over [0, n): every thread starts at its global index and
// advances by the total number of launched threads.
#define CUDA_1D_KERNEL_LOOP(i, n)                            \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
       i += blockDim.x * gridDim.x)

#define THREADS_PER_BLOCK 1024

// Number of blocks needed to cover N items with THREADS_PER_BLOCK threads
// each (rounded up), capped at 65000 blocks.
inline int GET_BLOCKS(const int N) {
  const int block_cap = 65000;
  const int blocks_needed = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
  return blocks_needed < block_cap ? blocks_needed : block_cap;
}
// Gathers input patches for the masked positions into a column buffer.
//
// The loop runs over n work items; each item is one (input channel, masked
// position) pair: `index % mask_cnt` selects the masked position and
// `index / mask_cnt` the channel. For that pair the kernel_h x kernel_w
// window anchored at (h - pad_h, w - pad_w) is copied into `data_col`;
// window elements that fall outside the image are written as 0.
//
// `data_im` is addressed as a single (channels, height, width) image.
// `data_col` is addressed with a row stride of mask_cnt: row
// c_im * kernel_h * kernel_w + i * kernel_w + j, column m_index.
template <typename scalar_t>
__global__ void MaskedIm2colForward(const int n, const scalar_t *data_im,
                                    const int height, const int width,
                                    const int kernel_h, const int kernel_w,
                                    const int pad_h, const int pad_w,
                                    const long *mask_h_idx,
                                    const long *mask_w_idx, const int mask_cnt,
                                    scalar_t *data_col) {
  // the launcher in this file passes n = mask_cnt * channels
  CUDA_1D_KERNEL_LOOP(index, n) {
    // decode the work item into masked position and input channel
    const int m_index = index % mask_cnt;
    const int h_col = mask_h_idx[m_index];
    const int w_col = mask_w_idx[m_index];
    const int c_im = index / mask_cnt;
    // first data_col row that belongs to this input channel
    const int c_col = c_im * kernel_h * kernel_w;
    // top-left corner of the kernel window in image coordinates
    const int h_offset = h_col - pad_h;
    const int w_offset = w_col - pad_w;
    scalar_t *data_col_ptr = data_col + c_col * mask_cnt + m_index;
    for (int i = 0; i < kernel_h; ++i) {
      int h_im = h_offset + i;
      for (int j = 0; j < kernel_w; ++j) {
        int w_im = w_offset + j;
        if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
          *data_col_ptr =
              (scalar_t)data_im[(c_im * height + h_im) * width + w_im];
        } else {
          // zero padding outside the image
          *data_col_ptr = 0.0;
        }
        // advance to the next data_col row (same masked position)
        data_col_ptr += mask_cnt;
      }
    }
  }
}
// Launches MaskedIm2colForward with one work item per
// (input channel, masked position) pair.
//
// bottom_data is the input image tensor, top_data the column buffer that is
// filled; mask_h_idx / mask_w_idx hold the coordinates of the mask_cnt masked
// positions and are read as `long`. Always returns 1; CUDA launch errors are
// reported through THCudaCheck.
int MaskedIm2colForwardLaucher(const at::Tensor bottom_data, const int height,
                               const int width, const int channels,
                               const int kernel_h, const int kernel_w,
                               const int pad_h, const int pad_w,
                               const at::Tensor mask_h_idx,
                               const at::Tensor mask_w_idx, const int mask_cnt,
                               at::Tensor top_data) {
  const int output_size = mask_cnt * channels;
  // With no masked position (or no channel) there is nothing to gather.
  // GET_BLOCKS(0) is 0, and a launch with a grid of zero blocks is rejected
  // by CUDA as an invalid configuration, so return before launching.
  if (output_size <= 0) {
    return 1;
  }
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      bottom_data.type(), "MaskedIm2colLaucherForward", ([&] {
        const scalar_t *bottom_data_ = bottom_data.data<scalar_t>();
        const long *mask_h_idx_ = mask_h_idx.data<long>();
        const long *mask_w_idx_ = mask_w_idx.data<long>();
        scalar_t *top_data_ = top_data.data<scalar_t>();
        MaskedIm2colForward<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
                output_size, bottom_data_, height, width, kernel_h, kernel_w,
                pad_h, pad_w, mask_h_idx_, mask_w_idx_, mask_cnt, top_data_);
      }));
  THCudaCheck(cudaGetLastError());
  return 1;
}
// Scatters per-position results back into a dense image.
//
// The loop runs over n work items; each item is one (channel, masked
// position) pair: `index % mask_cnt` selects the masked position and
// `index / mask_cnt` the channel. `data_col[index]` is written to `data_im`
// at (channel, h, w), where (h, w) are the coordinates of that masked
// position. Elements of `data_im` that belong to no masked position are not
// touched by this kernel. The `channels` parameter is not read here.
template <typename scalar_t>
__global__ void MaskedCol2imForward(const int n, const scalar_t *data_col,
                                    const int height, const int width,
                                    const int channels, const long *mask_h_idx,
                                    const long *mask_w_idx, const int mask_cnt,
                                    scalar_t *data_im) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    // decode the work item into masked position and channel
    const int m_index = index % mask_cnt;
    const int h_im = mask_h_idx[m_index];
    const int w_im = mask_w_idx[m_index];
    const int c_im = index / mask_cnt;
    // data_im is addressed as a single (channels, height, width) image
    data_im[(c_im * height + h_im) * width + w_im] = data_col[index];
  }
}
// Launches MaskedCol2imForward with one work item per
// (channel, masked position) pair.
//
// bottom_data holds the per-position results, top_data is the dense image
// that receives them; mask_h_idx / mask_w_idx hold the coordinates of the
// mask_cnt masked positions and are read as `long`. Always returns 1; CUDA
// launch errors are reported through THCudaCheck.
int MaskedCol2imForwardLaucher(const at::Tensor bottom_data, const int height,
                               const int width, const int channels,
                               const at::Tensor mask_h_idx,
                               const at::Tensor mask_w_idx, const int mask_cnt,
                               at::Tensor top_data) {
  const int output_size = mask_cnt * channels;
  // With no masked position (or no channel) there is nothing to scatter and
  // top_data is left as the caller provided it. GET_BLOCKS(0) is 0, and a
  // launch with a grid of zero blocks is rejected by CUDA as an invalid
  // configuration, so return before launching.
  if (output_size <= 0) {
    return 1;
  }
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      bottom_data.type(), "MaskedCol2imLaucherForward", ([&] {
        const scalar_t *bottom_data_ = bottom_data.data<scalar_t>();
        const long *mask_h_idx_ = mask_h_idx.data<long>();
        const long *mask_w_idx_ = mask_w_idx.data<long>();
        scalar_t *top_data_ = top_data.data<scalar_t>();
        MaskedCol2imForward<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK>>>(
                output_size, bottom_data_, height, width, channels, mask_h_idx_,
                mask_w_idx_, mask_cnt, top_data_);
      }));
  THCudaCheck(cudaGetLastError());
  return 1;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment