import mmcv
import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.nn as nn

from ..utils import ConvModule
from mmdet.core import mask_cross_entropy, mask_target


class FCNMaskHead(nn.Module):

    def __init__(self,
                 num_convs=4,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_kernel_size=3,
                 conv_out_channels=256,
                 upsample_method='deconv',
                 upsample_ratio=2,
                 num_classes=81,
                 class_agnostic=False,
                 normalize=None):
        super(FCNMaskHead, self).__init__()
        if upsample_method not in [None, 'deconv', 'nearest', 'bilinear']:
            raise ValueError(
                'Invalid upsample method {}, accepted methods '
                'are "deconv", "nearest", "bilinear"'.format(upsample_method))
        self.num_convs = num_convs
        self.roi_feat_size = roi_feat_size  # WARN: not used and reserved
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_out_channels = conv_out_channels
        self.upsample_method = upsample_method
        self.upsample_ratio = upsample_ratio
        self.num_classes = num_classes
        self.class_agnostic = class_agnostic
        self.normalize = normalize
        self.with_bias = normalize is None

        self.convs = nn.ModuleList()
        for i in range(self.num_convs):
            in_channels = (self.in_channels
                           if i == 0 else self.conv_out_channels)
            padding = (self.conv_kernel_size - 1) // 2
            self.convs.append(
                ConvModule(
                    in_channels,
                    self.conv_out_channels,
                    self.conv_kernel_size,
                    padding=padding,
                    normalize=normalize,
                    bias=self.with_bias))
        if self.upsample_method is None:
            self.upsample = None
        elif self.upsample_method == 'deconv':
            self.upsample = nn.ConvTranspose2d(
                self.conv_out_channels,
                self.conv_out_channels,
                self.upsample_ratio,
                stride=self.upsample_ratio)
        else:
            self.upsample = nn.Upsample(
                scale_factor=self.upsample_ratio, mode=self.upsample_method)

        out_channels = 1 if self.class_agnostic else self.num_classes
        self.conv_logits = nn.Conv2d(self.conv_out_channels, out_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.debug_imgs = None

    def init_weights(self):
        for m in [self.upsample, self.conv_logits]:
            # skip None and parameter-free modules (nn.Upsample has no weight)
            if m is None or not hasattr(m, 'weight'):
                continue
            nn.init.kaiming_normal_(
                m.weight, mode='fan_out', nonlinearity='relu')
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        if self.upsample is not None:
            x = self.upsample(x)
            if self.upsample_method == 'deconv':
                x = self.relu(x)
        mask_pred = self.conv_logits(x)
        return mask_pred

    def get_mask_target(self, pos_proposals, pos_assigned_gt_inds, gt_masks,
                        img_meta, rcnn_train_cfg):
        mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
                                   gt_masks, img_meta, rcnn_train_cfg)
        return mask_targets

    def loss(self, mask_pred, mask_targets, labels):
        loss = dict()
        loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
        loss['loss_mask'] = loss_mask
        return loss

    def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
                      ori_shape):
        """Get segmentation masks from mask_pred and bboxes.

        Args:
            mask_pred (Tensor or ndarray): shape (n, #class+1, h, w).
                For single-scale testing, mask_pred is the direct output of
                the model, whose type is Tensor, while for multi-scale
                testing, it will be converted to a numpy array outside of
                this method.
            det_bboxes (Tensor): shape (n, 4/5)
            det_labels (Tensor): shape (n, )
            rcnn_test_cfg (dict): rcnn testing config
            ori_shape (tuple): original image size

        Returns:
            list[list]: encoded masks
        """
        if isinstance(mask_pred, torch.Tensor):
            mask_pred = mask_pred.sigmoid().cpu().numpy()
        assert isinstance(mask_pred, np.ndarray)

        cls_segms = [[] for _ in range(self.num_classes - 1)]
        mask_size = mask_pred.shape[-1]
        bboxes = det_bboxes.cpu().numpy()[:, :4]
        labels = det_labels.cpu().numpy() + 1
        img_h = ori_shape[0]
        img_w = ori_shape[1]

        # Pad the predicted mask by 1 pixel on each side and enlarge the box
        # by the same ratio, so that resizing does not clip mask values at
        # the box boundary.
        scale = (mask_size + 2.0) / mask_size
        bboxes = np.round(self._bbox_scaling(bboxes, scale)).astype(np.int32)
        padded_mask = np.zeros(
            (mask_size + 2, mask_size + 2), dtype=np.float32)

        for i in range(bboxes.shape[0]):
            bbox = bboxes[i, :].astype(int)
            label = labels[i]
            w = max(bbox[2] - bbox[0] + 1, 1)
            h = max(bbox[3] - bbox[1] + 1, 1)

            if not self.class_agnostic:
                padded_mask[1:-1, 1:-1] = mask_pred[i, label, :, :]
            else:
                padded_mask[1:-1, 1:-1] = mask_pred[i, 0, :, :]

            # resize the mask to the box size, binarize it, and paste the
            # in-image part into a full-image canvas
            mask = mmcv.imresize(padded_mask, (w, h))
            mask = np.array(
                mask > rcnn_test_cfg.mask_thr_binary, dtype=np.uint8)
            im_mask = np.zeros((img_h, img_w), dtype=np.uint8)

            x0 = max(bbox[0], 0)
            x1 = min(bbox[2] + 1, img_w)
            y0 = max(bbox[1], 0)
            y1 = min(bbox[3] + 1, img_h)
            im_mask[y0:y1, x0:x1] = mask[(y0 - bbox[1]):(y1 - bbox[1]),
                                         (x0 - bbox[0]):(x1 - bbox[0])]
            rle = mask_util.encode(
                np.array(im_mask[:, :, np.newaxis], order='F'))[0]
            cls_segms[label - 1].append(rle)

        return cls_segms

    def _bbox_scaling(self, bboxes, scale, clip_shape=None):
        """Scale bboxes and optionally clip them to an image boundary.

        Args:
            bboxes (ndarray): shape (..., 4)
            scale (float): scaling factor
            clip_shape (None or tuple): (h, w) to clip against

        Returns:
            ndarray: scaled bboxes
        """
        if float(scale) == 1.0:
            scaled_bboxes = bboxes.copy()
        else:
            w = bboxes[..., 2] - bboxes[..., 0] + 1
            h = bboxes[..., 3] - bboxes[..., 1] + 1
            dw = (w * (scale - 1)) * 0.5
            dh = (h * (scale - 1)) * 0.5
            scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1)
        if clip_shape is not None:
            # clamp x coordinates to [0, w - 1] and y coordinates to
            # [0, h - 1], where clip_shape is (h, w)
            scaled_bboxes[..., 0::2] = np.clip(scaled_bboxes[..., 0::2], 0,
                                               clip_shape[1] - 1)
            scaled_bboxes[..., 1::2] = np.clip(scaled_bboxes[..., 1::2], 0,
                                               clip_shape[0] - 1)
        return scaled_bboxes
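

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module. It assumes the
# relative imports above resolve (i.e. this file lives inside the mmdet
# package) and uses dummy ROI features; the shapes below are illustrative,
# not values taken from the source.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    head = FCNMaskHead()
    head.init_weights()
    # 8 fake ROI feature maps of shape
    # (num_rois, in_channels, roi_feat_size, roi_feat_size)
    roi_feats = torch.rand(8, 256, 14, 14)
    mask_pred = head(roi_feats)
    # with the default deconv upsampling (ratio 2) and 81 classes, the
    # predicted logits have shape (8, 81, 28, 28)
    print(mask_pred.shape)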