Unverified Commit 5608b06a authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add COCO pre-trained weights for Faster R-CNN R-50 FPN (#925)

* Add COCO pre-trained weights for Faster R-CNN R-50 FPN

* Add weights for Mask R-CNN and Keypoint R-CNN
parent cffa4e00
...@@ -7,6 +7,8 @@ import torch.nn.functional as F ...@@ -7,6 +7,8 @@ import torch.nn.functional as F
from torchvision.ops import misc as misc_nn_ops from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign from torchvision.ops import MultiScaleRoIAlign
from ..utils import load_state_dict_from_url
from .generalized_rcnn import GeneralizedRCNN from .generalized_rcnn import GeneralizedRCNN
from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from .roi_heads import RoIHeads from .roi_heads import RoIHeads
...@@ -149,9 +151,21 @@ class FastRCNNPredictor(nn.Module): ...@@ -149,9 +151,21 @@ class FastRCNNPredictor(nn.Module):
return scores, bbox_deltas return scores, bbox_deltas
def fasterrcnn_resnet50_fpn(pretrained=False, num_classes=81, pretrained_backbone=True, **kwargs): model_urls = {
'fasterrcnn_resnet50_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
}
def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, **kwargs):
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
model = FasterRCNN(backbone, num_classes, **kwargs) model = FasterRCNN(backbone, num_classes, **kwargs)
if pretrained: if pretrained:
pass state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
progress=progress)
model.load_state_dict(state_dict)
return model return model
...@@ -4,6 +4,8 @@ from torch import nn ...@@ -4,6 +4,8 @@ from torch import nn
from torchvision.ops import misc as misc_nn_ops from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign from torchvision.ops import MultiScaleRoIAlign
from ..utils import load_state_dict_from_url
from .faster_rcnn import FasterRCNN from .faster_rcnn import FasterRCNN
from .backbone_utils import resnet_fpn_backbone from .backbone_utils import resnet_fpn_backbone
...@@ -127,10 +129,22 @@ class KeypointRCNNPredictor(nn.Module): ...@@ -127,10 +129,22 @@ class KeypointRCNNPredictor(nn.Module):
return x return x
def keypointrcnn_resnet50_fpn(pretrained=False, num_classes=2, num_keypoints=17, model_urls = {
'keypointrcnn_resnet50_fpn_coco':
'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth',
}
def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=2, num_keypoints=17,
pretrained_backbone=True, **kwargs): pretrained_backbone=True, **kwargs):
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if pretrained: if pretrained:
pass state_dict = load_state_dict_from_url(model_urls['keypointrcnn_resnet50_fpn_coco'],
progress=progress)
model.load_state_dict(state_dict)
return model return model
...@@ -7,6 +7,8 @@ import torch.nn.functional as F ...@@ -7,6 +7,8 @@ import torch.nn.functional as F
from torchvision.ops import misc as misc_nn_ops from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign from torchvision.ops import MultiScaleRoIAlign
from ..utils import load_state_dict_from_url
from .faster_rcnn import FasterRCNN from .faster_rcnn import FasterRCNN
from .backbone_utils import resnet_fpn_backbone from .backbone_utils import resnet_fpn_backbone
...@@ -125,9 +127,21 @@ class MaskRCNNPredictor(nn.Sequential): ...@@ -125,9 +127,21 @@ class MaskRCNNPredictor(nn.Sequential):
# nn.init.constant_(param, 0) # nn.init.constant_(param, 0)
def maskrcnn_resnet50_fpn(pretrained=False, num_classes=81, pretrained_backbone=True, **kwargs): model_urls = {
'maskrcnn_resnet50_fpn_coco':
'https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth',
}
def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, **kwargs):
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone) backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
model = MaskRCNN(backbone, num_classes, **kwargs) model = MaskRCNN(backbone, num_classes, **kwargs)
if pretrained: if pretrained:
pass state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'],
progress=progress)
model.load_state_dict(state_dict)
return model return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment