Unverified Commit cc0d1beb authored by Sai Krishna's avatar Sai Krishna Committed by GitHub
Browse files

Modifying keypoint_rcnn.py for keypoint_predictor issue (#5180)



* Modifying keypoint_rcnn.py

* Update torchvision/models/detection/keypoint_rcnn.py
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>

* Remove unnecessary new line
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 8265469b
...@@ -188,16 +188,18 @@ class KeypointRCNN(FasterRCNN): ...@@ -188,16 +188,18 @@ class KeypointRCNN(FasterRCNN):
keypoint_roi_pool=None, keypoint_roi_pool=None,
keypoint_head=None, keypoint_head=None,
keypoint_predictor=None, keypoint_predictor=None,
num_keypoints=17, num_keypoints=None,
): ):
assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))) assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
if min_size is None: if min_size is None:
min_size = (640, 672, 704, 736, 768, 800) min_size = (640, 672, 704, 736, 768, 800)
if num_classes is not None: if num_keypoints is not None:
if keypoint_predictor is not None: if keypoint_predictor is not None:
raise ValueError("num_classes should be None when keypoint_predictor is specified") raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
else:
num_keypoints = 17
out_channels = backbone.out_channels out_channels = backbone.out_channels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment