Unverified Commit 6bf920b9 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

[BC-breaking] Update KeypointRCNN weights (#1609)

* Update KeypointRCNN weights with correct file

* Fix model

* Fix
parent 99384107
......@@ -259,8 +259,11 @@ class KeypointRCNNPredictor(nn.Module):
model_urls = {
'keypointrcnn_resnet50_fpn_coco':
# legacy model for BC reasons, see https://github.com/pytorch/vision/issues/1606
'keypointrcnn_resnet50_fpn_coco_legacy':
'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth',
'keypointrcnn_resnet50_fpn_coco':
'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth',
}
......@@ -312,7 +315,10 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['keypointrcnn_resnet50_fpn_coco'],
key = 'keypointrcnn_resnet50_fpn_coco'
if pretrained == 'legacy':
key += '_legacy'
state_dict = load_state_dict_from_url(model_urls[key],
progress=progress)
model.load_state_dict(state_dict)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment