Unverified Commit 1d0b43ea authored by Vasilis Vryniotis, committed by GitHub

Unify onnx and JIT resize implementations (#3654)

* Make the two methods as similar as possible.

* Introduce conditional fake casting.

* Change the casting mechanism.
parent 07fb8ba7
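
Note on the mechanism: both the ONNX export path and the eager/TorchScript path now share a single _resize_image_and_masks, with the tracing-only helpers hidden behind @torch.jit.unused. Below is a minimal, self-contained sketch of the fake-casting pattern; the names my_fake_cast and compute_scale, and the use of torch.jit.is_tracing() in place of torchvision's internal _is_tracing(), are illustrative assumptions, not part of this commit.

import torch
from torch import Tensor


@torch.jit.unused
def my_fake_cast(v):
    # type: (Tensor) -> float
    # Skipped by the TorchScript compiler (@torch.jit.unused), so the JIT
    # trusts the declared float return type; under ONNX tracing the function
    # actually runs and returns the tensor unchanged, which keeps the scale
    # dynamic in the exported graph.
    return v


def compute_scale(image, min_size, max_size):
    # type: (Tensor, float, float) -> float
    im_shape = torch.tensor(image.shape[-2:])
    min_side = torch.min(im_shape).to(dtype=torch.float32)
    max_side = torch.max(im_shape).to(dtype=torch.float32)
    scale = torch.min(min_size / min_side, max_size / max_side)
    if torch.jit.is_tracing():
        # Tracing (e.g. torch.onnx.export): keep the value as a tensor.
        return my_fake_cast(scale)
    # Eager/TorchScript: materialize a real Python float.
    return scale.item()
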
@@ -10,36 +10,35 @@ from .roi_heads import paste_masks_in_image
 @torch.jit.unused
-def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
-    # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
+def _get_shape_onnx(image):
+    # type: (Tensor) -> Tensor
     from torch.onnx import operators
-    im_shape = operators.shape_as_tensor(image)[-2:]
-    min_size = torch.min(im_shape).to(dtype=torch.float32)
-    max_size = torch.max(im_shape).to(dtype=torch.float32)
-    scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
-    image = torch.nn.functional.interpolate(
-        image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
-        align_corners=False)[0]
-
-    if target is None:
-        return image, target
-
-    if "masks" in target:
-        mask = target["masks"]
-        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte()
-        target["masks"] = mask
-    return image, target
+    return operators.shape_as_tensor(image)[-2:]
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v):
+    # type: (Tensor) -> float
+    # ONNX requires a tensor but here we fake its type for JIT.
+    return v
 
 
 def _resize_image_and_masks(image, self_min_size, self_max_size, target):
     # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
-    im_shape = torch.tensor(image.shape[-2:])
-    min_size = float(torch.min(im_shape))
-    max_size = float(torch.max(im_shape))
-    scale_factor = self_min_size / min_size
-    if max_size * scale_factor > self_max_size:
-        scale_factor = self_max_size / max_size
+    if torchvision._is_tracing():
+        im_shape = _get_shape_onnx(image)
+    else:
+        im_shape = torch.tensor(image.shape[-2:])
+
+    min_size = torch.min(im_shape).to(dtype=torch.float32)
+    max_size = torch.max(im_shape).to(dtype=torch.float32)
+    scale = torch.min(self_min_size / min_size, self_max_size / max_size)
+
+    if torchvision._is_tracing():
+        scale_factor = _fake_cast_onnx(scale)
+    else:
+        scale_factor = scale.item()
+
     image = torch.nn.functional.interpolate(
         image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
        align_corners=False)[0]
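
For intuition, a hypothetical eager-mode smoke test of the unified helper (not part of the commit; the input sizes are made up, and it assumes the helper lives in torchvision.models.detection.transform after this change):

import torch
from torchvision.models.detection.transform import _resize_image_and_masks

image = torch.rand(3, 480, 640)  # made-up CHW image
resized, target = _resize_image_and_masks(image, 800.0, 1333.0, None)
print(resized.shape)  # roughly torch.Size([3, 800, 1066]): scale = min(800/480, 1333/640)
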
@@ -145,9 +144,6 @@ class GeneralizedRCNNTransform(nn.Module):
         else:
             # FIXME assume for now that testing uses the largest scale
             size = float(self.min_size[-1])
-        if torchvision._is_tracing():
-            image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
-        else:
-            image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
+        image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
 
         if target is None:
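
With the branching moved inside the helper, this call site is now trace-agnostic. Scripting should still work because the ONNX-only helpers are excluded from compilation via @torch.jit.unused; a hypothetical check (the model choice and flags are assumptions, not from this commit):

import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False).eval()
# _fake_cast_onnx is never compiled; TorchScript only sees its float signature.
scripted = torch.jit.script(model)
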