Unverified Commit 9bacd5c2 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add multiscale training for Keypoint R-CNN (#922)

parent cf401a70
...@@ -16,7 +16,7 @@ __all__ = [ ...@@ -16,7 +16,7 @@ __all__ = [
class KeypointRCNN(FasterRCNN): class KeypointRCNN(FasterRCNN):
def __init__(self, backbone, num_classes=None, def __init__(self, backbone, num_classes=None,
# transform parameters # transform parameters
min_size=800, max_size=1333, min_size=None, max_size=1333,
image_mean=None, image_std=None, image_mean=None, image_std=None,
# RPN parameters # RPN parameters
rpn_anchor_generator=None, rpn_head=None, rpn_anchor_generator=None, rpn_head=None,
...@@ -37,6 +37,8 @@ class KeypointRCNN(FasterRCNN): ...@@ -37,6 +37,8 @@ class KeypointRCNN(FasterRCNN):
num_keypoints=17): num_keypoints=17):
assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))) assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
if min_size is None:
min_size = (640, 672, 704, 736, 768, 800)
if num_classes is not None: if num_classes is not None:
if keypoint_predictor is not None: if keypoint_predictor is not None:
......
import random
import math import math
import torch import torch
from torch import nn from torch import nn
...@@ -10,8 +11,10 @@ from .roi_heads import paste_masks_in_image ...@@ -10,8 +11,10 @@ from .roi_heads import paste_masks_in_image
class GeneralizedRCNNTransform(nn.Module): class GeneralizedRCNNTransform(nn.Module):
def __init__(self, min_size, max_size, image_mean, image_std): def __init__(self, min_size, max_size, image_mean, image_std):
super(GeneralizedRCNNTransform, self).__init__() super(GeneralizedRCNNTransform, self).__init__()
self.min_size = float(min_size) if not isinstance(min_size, (list, tuple)):
self.max_size = float(max_size) min_size = (min_size,)
self.min_size = min_size
self.max_size = max_size
self.image_mean = image_mean self.image_mean = image_mean
self.image_std = image_std self.image_std = image_std
...@@ -40,9 +43,14 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -40,9 +43,14 @@ class GeneralizedRCNNTransform(nn.Module):
def resize(self, image, target): def resize(self, image, target):
h, w = image.shape[-2:] h, w = image.shape[-2:]
min_size = min(image.shape[-2:]) min_size = float(min(image.shape[-2:]))
max_size = max(image.shape[-2:]) max_size = float(max(image.shape[-2:]))
scale_factor = self.min_size / min_size if self.training:
size = random.choice(self.min_size)
else:
# FIXME assume for now that testing uses the largest scale
size = self.min_size[-1]
scale_factor = size / min_size
if max_size * scale_factor > self.max_size: if max_size * scale_factor > self.max_size:
scale_factor = self.max_size / max_size scale_factor = self.max_size / max_size
image = torch.nn.functional.interpolate( image = torch.nn.functional.interpolate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment