from __future__ import division import torch import torch.nn as nn from mmdet import ops class SingleLevelRoI(nn.Module): """Extract RoI features from a single level feature map. Each RoI is mapped to a level according to its scale.""" def __init__(self, roi_layer, out_channels, featmap_strides, finest_scale=56): super(SingleLevelRoI, self).__init__() self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides) self.out_channels = out_channels self.featmap_strides = featmap_strides self.finest_scale = finest_scale @property def num_inputs(self): return len(self.featmap_strides) def init_weights(self): pass def build_roi_layers(self, layer_cfg, featmap_strides): cfg = layer_cfg.copy() layer_type = cfg.pop('type') assert hasattr(ops, layer_type) layer_cls = getattr(ops, layer_type) roi_layers = nn.ModuleList( [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides]) return roi_layers def map_roi_levels(self, rois, num_levels): """Map rois to corresponding feature levels (0-based) by scales. scale < finest_scale: level 0 finest_scale <= scale < finest_scale * 2: level 1 finest_scale * 2 <= scale < finest_scale * 4: level 2 scale >= finest_scale * 4: level 3 """ scale = torch.sqrt( (rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1)) target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6)) target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long() return target_lvls def forward(self, feats, rois): """Extract roi features with the roi layer. If multiple feature levels are used, then rois are mapped to corresponding levels according to their scales. """ if len(feats) == 1: return self.roi_layers[0](feats[0], rois) out_size = self.roi_layers[0].out_size num_levels = len(feats) target_lvls = self.map_roi_levels(rois, num_levels) roi_feats = torch.cuda.FloatTensor(rois.size()[0], self.out_channels, out_size, out_size).fill_(0) for i in range(num_levels): inds = target_lvls == i if inds.any(): rois_ = rois[inds, :] roi_feats_t = self.roi_layers[i](feats[i], rois_) roi_feats[inds] += roi_feats_t return roi_feats