Unverified Commit b40f49fd authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Remove interpolate in favor of PyTorch's implementation (#2252)

* Remove interpolate in favor of PyTorch's implementation

* Bugfix

* Bugfix
parent 98aa805e
import torch import torch
from torch import nn from torch import nn
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign from torchvision.ops import MultiScaleRoIAlign
from ..utils import load_state_dict_from_url from ..utils import load_state_dict_from_url
...@@ -253,7 +251,7 @@ class KeypointRCNNPredictor(nn.Module): ...@@ -253,7 +251,7 @@ class KeypointRCNNPredictor(nn.Module):
def forward(self, x):
    """Score keypoints and upsample the heatmaps by `self.up_scale`.

    Runs the low-resolution keypoint scorer, then bilinearly interpolates
    the result up by a factor of `self.up_scale` (align_corners=False to
    match the detection pipeline's convention).
    """
    scores = self.kps_score_lowres(x)
    return torch.nn.functional.interpolate(
        scores,
        scale_factor=float(self.up_scale),
        mode="bilinear",
        align_corners=False,
    )
......
...@@ -5,7 +5,6 @@ import torch.nn.functional as F ...@@ -5,7 +5,6 @@ import torch.nn.functional as F
from torch import nn, Tensor from torch import nn, Tensor
from torchvision.ops import boxes as box_ops from torchvision.ops import boxes as box_ops
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import roi_align from torchvision.ops import roi_align
...@@ -175,8 +174,8 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height, ...@@ -175,8 +174,8 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height,
width_correction = widths_i / roi_map_width width_correction = widths_i / roi_map_width
height_correction = heights_i / roi_map_height height_correction = heights_i / roi_map_height
roi_map = torch.nn.functional.interpolate( roi_map = F.interpolate(
maps_i[None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[0] maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[:, 0]
w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64) w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
...@@ -256,8 +255,8 @@ def heatmaps_to_keypoints(maps, rois): ...@@ -256,8 +255,8 @@ def heatmaps_to_keypoints(maps, rois):
roi_map_height = int(heights_ceil[i].item()) roi_map_height = int(heights_ceil[i].item())
width_correction = widths[i] / roi_map_width width_correction = widths[i] / roi_map_width
height_correction = heights[i] / roi_map_height height_correction = heights[i] / roi_map_height
roi_map = torch.nn.functional.interpolate( roi_map = F.interpolate(
maps[i][None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[0] maps[i][:, None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[:, 0]
# roi_map_probs = scores_to_probs(roi_map.copy()) # roi_map_probs = scores_to_probs(roi_map.copy())
w = roi_map.shape[2] w = roi_map.shape[2]
pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
...@@ -392,7 +391,7 @@ def paste_mask_in_image(mask, box, im_h, im_w): ...@@ -392,7 +391,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
mask = mask.expand((1, 1, -1, -1)) mask = mask.expand((1, 1, -1, -1))
# Resize mask # Resize mask
mask = misc_nn_ops.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask = mask[0][0] mask = mask[0][0]
im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device) im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
...@@ -420,7 +419,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w): ...@@ -420,7 +419,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
mask = mask.expand((1, 1, mask.size(0), mask.size(1))) mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
# Resize mask # Resize mask
mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False) mask = F.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
mask = mask[0][0] mask = mask[0][0]
x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero))) x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
......
...@@ -2,10 +2,10 @@ import random ...@@ -2,10 +2,10 @@ import random
import math import math
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
from torch.nn import functional as F
import torchvision import torchvision
from torch.jit.annotations import List, Tuple, Dict, Optional from torch.jit.annotations import List, Tuple, Dict, Optional
from torchvision.ops import misc as misc_nn_ops
from .image_list import ImageList from .image_list import ImageList
from .roi_heads import paste_masks_in_image from .roi_heads import paste_masks_in_image
...@@ -28,7 +28,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target): ...@@ -28,7 +28,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
if "masks" in target: if "masks" in target:
mask = target["masks"] mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte() mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
target["masks"] = mask target["masks"] = mask
return image, target return image, target
...@@ -50,7 +50,7 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target): ...@@ -50,7 +50,7 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target):
if "masks" in target: if "masks" in target:
mask = target["masks"] mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte() mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
target["masks"] = mask target["masks"] = mask
return image, target return image, target
......
from collections import OrderedDict
from torch.jit.annotations import Optional, List
from torch import Tensor
""" """
helper class that supports empty tensors on some nn functions. helper class that supports empty tensors on some nn functions.
...@@ -12,10 +8,8 @@ This can be removed once https://github.com/pytorch/pytorch/issues/12013 ...@@ -12,10 +8,8 @@ This can be removed once https://github.com/pytorch/pytorch/issues/12013
is implemented is implemented
""" """
import math
import warnings import warnings
import torch import torch
from torchvision.ops import _new_empty_tensor
class Conv2d(torch.nn.Conv2d): class Conv2d(torch.nn.Conv2d):
...@@ -42,51 +36,7 @@ class BatchNorm2d(torch.nn.BatchNorm2d): ...@@ -42,51 +36,7 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
"removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning) "removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning)
def _check_size_scale_factor(dim, size, scale_factor): interpolate = torch.nn.functional.interpolate
# type: (int, Optional[List[int]], Optional[float]) -> None
if size is None and scale_factor is None:
raise ValueError("either size or scale_factor should be defined")
if size is not None and scale_factor is not None:
raise ValueError("only one of size or scale_factor should be defined")
if scale_factor is not None:
if isinstance(scale_factor, (list, tuple)):
if len(scale_factor) != dim:
raise ValueError(
"scale_factor shape must match input shape. "
"Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
)
def _output_size(dim, input, size, scale_factor):
# type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
assert dim == 2
_check_size_scale_factor(dim, size, scale_factor)
if size is not None:
return size
# if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
assert scale_factor is not None and isinstance(scale_factor, (int, float))
scale_factors = [scale_factor, scale_factor]
# math.floor might return float in py2.7
return [
int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
]
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """Drop-in replacement for nn.functional.interpolate that also accepts
    empty batches.

    PyTorch's interpolate cannot handle zero-element inputs, so for those we
    compute the would-be output shape ourselves and return an empty tensor of
    that shape. This shim can be removed once PyTorch supports empty batches
    natively.
    """
    if input.numel() == 0:
        spatial = _output_size(2, input, size, scale_factor)
        full_shape = list(input.shape[:-2]) + list(spatial)
        return _new_empty_tensor(input, full_shape)
    return torch.nn.functional.interpolate(
        input, size, scale_factor, mode, align_corners
    )
# This is not in nn # This is not in nn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment