import warnings

import torch
from torch import nn

from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign

from .faster_rcnn import FasterRCNN
from .backbone_utils import resnet_fpn_backbone


__all__ = [
    "KeypointRCNN", "keypointrcnn_resnet50_fpn"
]


class KeypointRCNN(FasterRCNN):
    """Faster R-CNN extended with a keypoint-prediction branch.

    All transform, RPN and box parameters are forwarded positionally to
    ``FasterRCNN.__init__``. After the parent constructor runs, the keypoint
    components are attached to ``self.roi_heads``.

    Keypoint-specific arguments:
        keypoint_roi_pool (MultiScaleRoIAlign or None): pooler for keypoint
            features. When None, a ``MultiScaleRoIAlign`` over feature maps
            ``[0, 1, 2, 3]`` with ``output_size=14`` and ``sampling_ratio=2``
            is created.
        keypoint_head (nn.Module or None): when None, a ``KeypointRCNNHeads``
            of eight 512-channel conv layers is built on top of
            ``backbone.out_channels``.
        keypoint_predictor (nn.Module or None): when None, a
            ``KeypointRCNNPredictor(512, num_keypoints)`` is created.
        keypoint_discretization_size (int): stored as
            ``roi_heads.keypoint_discretization_size``.
        num_keypoints (int): number of keypoints for the default predictor.

    ``min_size`` defaults to ``(640, 672, 704, 736, 768, 800)`` when None.

    Raises:
        TypeError: if ``keypoint_roi_pool`` is neither a
            ``MultiScaleRoIAlign`` nor None.
        ValueError: if both ``num_classes`` and ``keypoint_predictor`` are
            given.
    """

    def __init__(self, backbone, num_classes=None,
                 # transform parameters
                 min_size=None, max_size=1333,
                 image_mean=None, image_std=None,
                 # RPN parameters
                 rpn_anchor_generator=None, rpn_head=None,
                 rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
                 rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,
                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
                 box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
                 box_batch_size_per_image=512, box_positive_fraction=0.25,
                 bbox_reg_weights=None,
                 # keypoint parameters
                 keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None,
                 keypoint_discretization_size=56,
                 num_keypoints=17):

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
            raise TypeError(
                "keypoint_roi_pool should be of type MultiScaleRoIAlign or None, "
                "got {}".format(type(keypoint_roi_pool).__name__)
            )
        if min_size is None:
            min_size = (640, 672, 704, 736, 768, 800)

        if num_classes is not None:
            if keypoint_predictor is not None:
                raise ValueError("num_classes should be None when keypoint_predictor is specified")

        out_channels = backbone.out_channels

        if keypoint_roi_pool is None:
            keypoint_roi_pool = MultiScaleRoIAlign(
                featmap_names=[0, 1, 2, 3],
                output_size=14,
                sampling_ratio=2)

        if keypoint_head is None:
            keypoint_layers = tuple(512 for _ in range(8))
            keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)

        if keypoint_predictor is None:
            # 512 matches the last layer of the *default* head above. If a
            # custom keypoint_head with a different output width is supplied
            # without a matching keypoint_predictor, the channel counts will
            # not line up.
            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
            keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)

        super(KeypointRCNN, self).__init__(
            backbone, num_classes,
            # transform parameters
            min_size, max_size,
            image_mean, image_std,
            # RPN-specific parameters
            rpn_anchor_generator, rpn_head,
            rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train, rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh, rpn_bg_iou_thresh,
            rpn_batch_size_per_image, rpn_positive_fraction,
            # Box parameters
            box_roi_pool, box_head, box_predictor,
            box_score_thresh, box_nms_thresh, box_detections_per_img,
            box_fg_iou_thresh, box_bg_iou_thresh,
            box_batch_size_per_image, box_positive_fraction,
            bbox_reg_weights)

        # Attach the keypoint branch to the ROI heads built by FasterRCNN.
        self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
        self.roi_heads.keypoint_head = keypoint_head
        self.roi_heads.keypoint_predictor = keypoint_predictor
        self.roi_heads.keypoint_discretization_size = keypoint_discretization_size


class KeypointRCNNHeads(nn.Sequential):
    """Sequential stack of 3x3 conv + ReLU layers used as the keypoint head.

    Args:
        in_channels (int): channels of the input feature map.
        layers (Iterable[int]): output channels of each successive conv layer.
    """

    def __init__(self, in_channels, layers):
        d = []
        next_feature = in_channels
        for layer_channels in layers:
            # stride=1, padding=1 with a 3x3 kernel preserves spatial size.
            d.append(misc_nn_ops.Conv2d(next_feature, layer_channels, 3, stride=1, padding=1))
            d.append(nn.ReLU(inplace=True))
            next_feature = layer_channels
        super(KeypointRCNNHeads, self).__init__(*d)
        # Kaiming (fan_out, relu) weights and zero bias for every conv layer.
        for m in self.children():
            if isinstance(m, misc_nn_ops.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)


class KeypointRCNNPredictor(nn.Module):
    """Map head features to per-keypoint score maps.

    A stride-2 transposed conv (kernel 4, padding 1) doubles the spatial
    size. A bilinear interpolation by ``up_scale`` (2) then doubles it again.

    Args:
        in_channels (int): channels of the incoming head features.
        num_keypoints (int): number of output channels (one per keypoint).
    """

    def __init__(self, in_channels, num_keypoints):
        super(KeypointRCNNPredictor, self).__init__()
        input_features = in_channels
        deconv_kernel = 4
        self.kps_score_lowres = misc_nn_ops.ConvTranspose2d(
            input_features,
            num_keypoints,
            deconv_kernel,
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(
            self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
        )
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = 2
        self.out_channels = num_keypoints

    def forward(self, x):
        """Return keypoint logits upsampled by the deconv and then bilinearly."""
        x = self.kps_score_lowres(x)
        x = misc_nn_ops.interpolate(
            x, scale_factor=self.up_scale, mode="bilinear", align_corners=False
        )
        return x


def keypointrcnn_resnet50_fpn(pretrained=False, num_classes=2, num_keypoints=17,
                              pretrained_backbone=True, **kwargs):
    """Build a KeypointRCNN on a ResNet-50 FPN backbone.

    Args:
        pretrained (bool): no pretrained KeypointRCNN weights are loaded by
            this function. Passing True only emits a ``UserWarning``.
        num_classes (int): forwarded to ``KeypointRCNN``.
        num_keypoints (int): forwarded to ``KeypointRCNN``.
        pretrained_backbone (bool): forwarded to ``resnet_fpn_backbone``.
        **kwargs: extra keyword arguments for ``KeypointRCNN``.

    Returns:
        KeypointRCNN: the constructed model.
    """
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
    if pretrained:
        # Previously this branch was a silent `pass`. Warn so callers don't
        # assume they received a model with pretrained detector weights.
        warnings.warn(
            "keypointrcnn_resnet50_fpn: pretrained=True is ignored; "
            "no pretrained KeypointRCNN weights are loaded",
            UserWarning,
        )
    return model