Unverified Commit 9bacd5c2 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add multiscale training for Keypoint R-CNN (#922)

parent cf401a70
......@@ -16,7 +16,7 @@ __all__ = [
class KeypointRCNN(FasterRCNN):
def __init__(self, backbone, num_classes=None,
# transform parameters
min_size=800, max_size=1333,
min_size=None, max_size=1333,
image_mean=None, image_std=None,
# RPN parameters
rpn_anchor_generator=None, rpn_head=None,
......@@ -37,6 +37,8 @@ class KeypointRCNN(FasterRCNN):
num_keypoints=17):
assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
if min_size is None:
min_size = (640, 672, 704, 736, 768, 800)
if num_classes is not None:
if keypoint_predictor is not None:
......
import random
import math
import torch
from torch import nn
......@@ -10,8 +11,10 @@ from .roi_heads import paste_masks_in_image
class GeneralizedRCNNTransform(nn.Module):
def __init__(self, min_size, max_size, image_mean, image_std):
super(GeneralizedRCNNTransform, self).__init__()
self.min_size = float(min_size)
self.max_size = float(max_size)
if not isinstance(min_size, (list, tuple)):
min_size = (min_size,)
self.min_size = min_size
self.max_size = max_size
self.image_mean = image_mean
self.image_std = image_std
......@@ -40,9 +43,14 @@ class GeneralizedRCNNTransform(nn.Module):
def resize(self, image, target):
h, w = image.shape[-2:]
min_size = min(image.shape[-2:])
max_size = max(image.shape[-2:])
scale_factor = self.min_size / min_size
min_size = float(min(image.shape[-2:]))
max_size = float(max(image.shape[-2:]))
if self.training:
size = random.choice(self.min_size)
else:
# FIXME assume for now that testing uses the largest scale
size = self.min_size[-1]
scale_factor = size / min_size
if max_size * scale_factor > self.max_size:
scale_factor = self.max_size / max_size
image = torch.nn.functional.interpolate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment