Unverified Commit b40f49fd authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Remove interpolate in favor of PyTorch's implementation (#2252)

* Remove interpolate in favor of PyTorch's implementation

* Bugfix

* Bugfix
parent 98aa805e
import torch
from torch import nn
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign
from ..utils import load_state_dict_from_url
......@@ -253,7 +251,7 @@ class KeypointRCNNPredictor(nn.Module):
def forward(self, x):
x = self.kps_score_lowres(x)
x = misc_nn_ops.interpolate(
x = torch.nn.functional.interpolate(
x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False
)
return x
......
......@@ -5,7 +5,6 @@ import torch.nn.functional as F
from torch import nn, Tensor
from torchvision.ops import boxes as box_ops
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import roi_align
......@@ -175,8 +174,8 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height,
width_correction = widths_i / roi_map_width
height_correction = heights_i / roi_map_height
roi_map = torch.nn.functional.interpolate(
maps_i[None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[0]
roi_map = F.interpolate(
maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[:, 0]
w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
......@@ -256,8 +255,8 @@ def heatmaps_to_keypoints(maps, rois):
roi_map_height = int(heights_ceil[i].item())
width_correction = widths[i] / roi_map_width
height_correction = heights[i] / roi_map_height
roi_map = torch.nn.functional.interpolate(
maps[i][None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[0]
roi_map = F.interpolate(
maps[i][:, None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[:, 0]
# roi_map_probs = scores_to_probs(roi_map.copy())
w = roi_map.shape[2]
pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
......@@ -392,7 +391,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
mask = mask.expand((1, 1, -1, -1))
# Resize mask
mask = misc_nn_ops.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask = mask[0][0]
im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
......@@ -420,7 +419,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
# Resize mask
mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
mask = F.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
mask = mask[0][0]
x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
......
......@@ -2,10 +2,10 @@ import random
import math
import torch
from torch import nn, Tensor
from torch.nn import functional as F
import torchvision
from torch.jit.annotations import List, Tuple, Dict, Optional
from torchvision.ops import misc as misc_nn_ops
from .image_list import ImageList
from .roi_heads import paste_masks_in_image
......@@ -28,7 +28,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
if "masks" in target:
mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
target["masks"] = mask
return image, target
......@@ -50,7 +50,7 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target):
if "masks" in target:
mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
target["masks"] = mask
return image, target
......
from collections import OrderedDict
from torch.jit.annotations import Optional, List
from torch import Tensor
"""
helper class that supports empty tensors on some nn functions.
......@@ -12,10 +8,8 @@ This can be removed once https://github.com/pytorch/pytorch/issues/12013
is implemented
"""
import math
import warnings
import torch
from torchvision.ops import _new_empty_tensor
class Conv2d(torch.nn.Conv2d):
......@@ -42,51 +36,7 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
"removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning)
def _check_size_scale_factor(dim, size, scale_factor):
# type: (int, Optional[List[int]], Optional[float]) -> None
if size is None and scale_factor is None:
raise ValueError("either size or scale_factor should be defined")
if size is not None and scale_factor is not None:
raise ValueError("only one of size or scale_factor should be defined")
if scale_factor is not None:
if isinstance(scale_factor, (list, tuple)):
if len(scale_factor) != dim:
raise ValueError(
"scale_factor shape must match input shape. "
"Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
)
def _output_size(dim, input, size, scale_factor):
# type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
assert dim == 2
_check_size_scale_factor(dim, size, scale_factor)
if size is not None:
return size
# if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
assert scale_factor is not None and isinstance(scale_factor, (int, float))
scale_factors = [scale_factor, scale_factor]
# math.floor might return float in py2.7
return [
int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
]
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if input.numel() > 0:
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
interpolate = torch.nn.functional.interpolate
# This is not in nn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment