from collections import OrderedDict
import warnings

import torch
from torch import nn
import torch.nn.functional as F

from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign

from .faster_rcnn import FasterRCNN
from .backbone_utils import resnet_fpn_backbone


__all__ = [
    "MaskRCNN", "maskrcnn_resnet50_fpn",
]


class MaskRCNN(FasterRCNN):
    """Faster R-CNN extended with a mask branch attached to its RoI heads.

    The constructor builds (or accepts) three mask components -- a RoI
    pooler, a convolutional head and a predictor -- and attaches them to
    ``self.roi_heads`` after the parent ``FasterRCNN`` has been initialised.

    Arguments:
        backbone (nn.Module): feature extractor; must expose an
            ``out_channels`` attribute.
        num_classes (int or None): number of output classes. Must be None
            when ``mask_predictor`` is given.
        min_size, max_size, image_mean, image_std: transform parameters,
            forwarded unchanged to ``FasterRCNN.__init__``.
        rpn_*: RPN parameters, forwarded unchanged to ``FasterRCNN.__init__``.
        box_*, bbox_reg_weights: box parameters, forwarded unchanged to
            ``FasterRCNN.__init__``.
        mask_roi_pool (MultiScaleRoIAlign or None): RoI pooler for the mask
            branch; a default one (output size 14, sampling ratio 2, feature
            maps 0-3) is created when None.
        mask_head (nn.Module or None): module applied to the pooled mask
            features; defaults to ``MaskRCNNHeads`` with four 256-channel
            layers and dilation 1.
        mask_predictor (nn.Module or None): module producing the mask
            logits; defaults to ``MaskRCNNPredictor``.
        mask_discretization_size (int): stored on ``self.roi_heads`` as
            ``mask_discretization_size``.
            NOTE(review): presumably the side length of the training mask
            targets -- confirm against the RoI heads implementation.

    Raises:
        ValueError: if both ``num_classes`` and ``mask_predictor`` are given.
    """

    def __init__(self, backbone, num_classes=None,
                 # transform parameters
                 min_size=800, max_size=1333,
                 image_mean=None, image_std=None,
                 # RPN parameters
                 rpn_anchor_generator=None, rpn_head=None,
                 rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
                 rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,
                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
                 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
                 box_batch_size_per_image=512, box_positive_fraction=0.25,
                 bbox_reg_weights=None,
                 # Mask parameters
                 mask_roi_pool=None, mask_head=None, mask_predictor=None,
                 mask_discretization_size=28):

        assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None)))

        if num_classes is not None:
            if mask_predictor is not None:
                raise ValueError("num_classes should be None when mask_predictor is specified")

        out_channels = backbone.out_channels

        if mask_roi_pool is None:
            mask_roi_pool = MultiScaleRoIAlign(
                featmap_names=[0, 1, 2, 3],
                output_size=14,
                sampling_ratio=2)

        # The predictor consumes the output of the mask head, so its input
        # width must match the head's last layer rather than the backbone.
        # For a caller-supplied head the output width is unknown here, so
        # the backbone width is kept as before.
        mask_predictor_in_channels = out_channels

        if mask_head is None:
            mask_layers = (256, 256, 256, 256)
            mask_dilation = 1
            mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation)
            mask_predictor_in_channels = mask_layers[-1]

        if mask_predictor is None:
            mask_dim_reduced = 256  # == mask_layers[-1]
            mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels,
                                               mask_dim_reduced, num_classes)

        super(MaskRCNN, self).__init__(
            backbone, num_classes,
            # transform parameters
            min_size, max_size,
            image_mean, image_std,
            # RPN-specific parameters
            rpn_anchor_generator, rpn_head,
            rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train, rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh, rpn_bg_iou_thresh,
            rpn_batch_size_per_image, rpn_positive_fraction,
            # Box parameters
            box_roi_pool, box_head, box_predictor,
            box_score_thresh, box_nms_thresh, box_detections_per_img,
            box_fg_iou_thresh, box_bg_iou_thresh,
            box_batch_size_per_image, box_positive_fraction,
            bbox_reg_weights)

        # Attach the mask branch only after the parent constructor has
        # created ``self.roi_heads``.
        self.roi_heads.mask_roi_pool = mask_roi_pool
        self.roi_heads.mask_head = mask_head
        self.roi_heads.mask_predictor = mask_predictor

        self.roi_heads.mask_discretization_size = mask_discretization_size


class MaskRCNNHeads(nn.Sequential):
    def __init__(self, in_channels, layers, dilation):
        """
        Stack of 3x3 convolutions, each followed by an in-place ReLU.

        Arguments:
            in_channels (int): number of input channels of the first convolution
            layers (iterable of int): output channels of each convolution, in order
            dilation (int): dilation of every convolution; also used as its padding
        """
        d = OrderedDict()
        next_feature = in_channels
        for layer_idx, layer_features in enumerate(layers, 1):
            d["mask_fcn{}".format(layer_idx)] = misc_nn_ops.Conv2d(
                next_feature, layer_features, kernel_size=3,
                stride=1, padding=dilation, dilation=dilation)
            d["relu{}".format(layer_idx)] = nn.ReLU(inplace=True)
            next_feature = layer_features

        super(MaskRCNNHeads, self).__init__(d)
        # Only the weights are re-initialised; biases keep whatever
        # initialisation the convolution layers gave them.
        for name, param in self.named_parameters():
            if "weight" in name:
                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")


class MaskRCNNPredictor(nn.Sequential):
    def __init__(self, in_channels, dim_reduced, num_classes):
        """
        Transposed convolution (kernel 2, stride 2), in-place ReLU, then a
        1x1 convolution producing one output channel per class.

        Arguments:
            in_channels (int): number of channels of the incoming features
            dim_reduced (int): number of channels after the transposed convolution
            num_classes (int): number of output channels of the final convolution
        """
        super(MaskRCNNPredictor, self).__init__(OrderedDict([
            ("conv5_mask", misc_nn_ops.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
            ("relu", nn.ReLU(inplace=True)),
            ("mask_fcn_logits", misc_nn_ops.Conv2d(dim_reduced, num_classes, 1, 1, 0)),
        ]))

        # Only the weights are re-initialised; biases keep whatever
        # initialisation the layers gave them.
        for name, param in self.named_parameters():
            if "weight" in name:
                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")


def maskrcnn_resnet50_fpn(pretrained=False, num_classes=81, pretrained_backbone=True, **kwargs):
    """
    Build a Mask R-CNN model on top of a ResNet-50 FPN backbone.

    Arguments:
        pretrained (bool): accepted for API compatibility; no pretrained
            detection weights are loaded by this function, so a warning is
            emitted when it is True
        num_classes (int): number of output classes passed to ``MaskRCNN``
        pretrained_backbone (bool): forwarded to ``resnet_fpn_backbone``
        **kwargs: extra keyword arguments forwarded to ``MaskRCNN``

    Returns:
        MaskRCNN: the constructed model
    """
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
    model = MaskRCNN(backbone, num_classes, **kwargs)
    if pretrained:
        # Previously a silent no-op: tell the caller the request was not honoured.
        warnings.warn(
            "pretrained=True was requested, but maskrcnn_resnet50_fpn does not "
            "load pretrained detection weights; returning the model without them.",
            stacklevel=2)
    return model