Unverified Commit 1d0b43ea authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Unify onnx and JIT resize implementations (#3654)

* Make two methods as similar as possible.

* Introducing conditional fake casting.

* Change the casting mechanism.
parent 07fb8ba7
......@@ -10,36 +10,35 @@ from .roi_heads import paste_masks_in_image
@torch.jit.unused
def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
# type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
def _get_shape_onnx(image):
# type: (Tensor) -> Tensor
from torch.onnx import operators
im_shape = operators.shape_as_tensor(image)[-2:]
min_size = torch.min(im_shape).to(dtype=torch.float32)
max_size = torch.max(im_shape).to(dtype=torch.float32)
scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
image = torch.nn.functional.interpolate(
image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
align_corners=False)[0]
return operators.shape_as_tensor(image)[-2:]
if target is None:
return image, target
if "masks" in target:
mask = target["masks"]
mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte()
target["masks"] = mask
return image, target
@torch.jit.unused
def _fake_cast_onnx(v):
# type: (Tensor) -> float
# ONNX requires a tensor but here we fake its type for JIT.
return v
def _resize_image_and_masks(image, self_min_size, self_max_size, target):
    # type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
    """Rescale ``image`` (and any segmentation masks in ``target``) so that the
    shorter side is ``self_min_size`` without the longer side exceeding
    ``self_max_size``.

    Unified implementation for eager, TorchScript and ONNX export: the tracing
    branches keep shapes/scales as Tensors so the exported graph stays dynamic.

    Args:
        image: input of shape (..., H, W) — assumed CHW here; confirm at caller.
        self_min_size: target size of the smaller side.
        self_max_size: upper bound for the larger side.
        target: optional detection target dict; only its "masks" entry is
            resized in place.

    Returns:
        The resized image and the (possibly modified) target.
    """
    if torchvision._is_tracing():
        # ONNX export: keep the shape symbolic instead of Python ints.
        im_shape = _get_shape_onnx(image)
    else:
        im_shape = torch.tensor(image.shape[-2:])

    min_size = torch.min(im_shape).to(dtype=torch.float32)
    max_size = torch.max(im_shape).to(dtype=torch.float32)
    # Scale the short side up to self_min_size unless that would push the long
    # side past self_max_size, in which case the long side is the constraint.
    scale = torch.min(self_min_size / min_size, self_max_size / max_size)

    if torchvision._is_tracing():
        # Tensor is passed through; the float annotation only appeases JIT.
        scale_factor = _fake_cast_onnx(scale)
    else:
        scale_factor = scale.item()

    image = torch.nn.functional.interpolate(
        image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,
        align_corners=False)[0]

    if target is None:
        return image, target

    if "masks" in target:
        # Nearest-neighbour (default) interpolation keeps masks binary; cast
        # back to byte after resizing in float.
        mask = target["masks"]
        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor,
                             recompute_scale_factor=True)[:, 0].byte()
        target["masks"] = mask
    return image, target
......@@ -145,10 +144,7 @@ class GeneralizedRCNNTransform(nn.Module):
else:
# FIXME assume for now that testing uses the largest scale
size = float(self.min_size[-1])
if torchvision._is_tracing():
image, target = _resize_image_and_masks_onnx(image, size, float(self.max_size), target)
else:
image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
image, target = _resize_image_and_masks(image, size, float(self.max_size), target)
if target is None:
return image, target
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.