Unverified Commit 300a9092 authored by Edward Z. Yang's avatar Edward Z. Yang Committed by GitHub
Browse files

Add non-TS'able _resize_image_and_masks variant with less tensor ops (#7592)


Signed-off-by: default avatarEdward Z. Yang <ezyang@meta.com>
parent d2f7486c
...@@ -24,8 +24,8 @@ def _fake_cast_onnx(v: Tensor) -> float: ...@@ -24,8 +24,8 @@ def _fake_cast_onnx(v: Tensor) -> float:
def _resize_image_and_masks( def _resize_image_and_masks(
image: Tensor, image: Tensor,
self_min_size: float, self_min_size: int,
self_max_size: float, self_max_size: int,
target: Optional[Dict[str, Tensor]] = None, target: Optional[Dict[str, Tensor]] = None,
fixed_size: Optional[Tuple[int, int]] = None, fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
...@@ -40,14 +40,24 @@ def _resize_image_and_masks( ...@@ -40,14 +40,24 @@ def _resize_image_and_masks(
if fixed_size is not None: if fixed_size is not None:
size = [fixed_size[1], fixed_size[0]] size = [fixed_size[1], fixed_size[0]]
else: else:
if torch.jit.is_scripting() or torchvision._is_tracing():
min_size = torch.min(im_shape).to(dtype=torch.float32) min_size = torch.min(im_shape).to(dtype=torch.float32)
max_size = torch.max(im_shape).to(dtype=torch.float32) max_size = torch.max(im_shape).to(dtype=torch.float32)
scale = torch.min(self_min_size / min_size, self_max_size / max_size) self_min_size_f = float(self_min_size)
self_max_size_f = float(self_max_size)
scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)
if torchvision._is_tracing(): if torchvision._is_tracing():
scale_factor = _fake_cast_onnx(scale) scale_factor = _fake_cast_onnx(scale)
else: else:
scale_factor = scale.item() scale_factor = scale.item()
else:
# Do it the normal way
min_size = min(im_shape)
max_size = max(im_shape)
scale_factor = min(self_min_size / min_size, self_max_size / max_size)
recompute_scale_factor = True recompute_scale_factor = True
image = torch.nn.functional.interpolate( image = torch.nn.functional.interpolate(
...@@ -159,8 +169,7 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -159,8 +169,7 @@ class GeneralizedRCNNTransform(nn.Module):
def torch_choice(self, k: List[int]) -> int: def torch_choice(self, k: List[int]) -> int:
""" """
Implements `random.choice` via torch ops, so it can be compiled with Implements `random.choice` via torch ops, so it can be compiled with
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803 TorchScript and we use PyTorch's RNG (not native RNG)
is fixed.
""" """
index = int(torch.empty(1).uniform_(0.0, float(len(k))).item()) index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
return k[index] return k[index]
...@@ -174,11 +183,10 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -174,11 +183,10 @@ class GeneralizedRCNNTransform(nn.Module):
if self.training: if self.training:
if self._skip_resize: if self._skip_resize:
return image, target return image, target
size = float(self.torch_choice(self.min_size)) size = self.torch_choice(self.min_size)
else: else:
# FIXME assume for now that testing uses the largest scale size = self.min_size[-1]
size = float(self.min_size[-1]) image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)
image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size)
if target is None: if target is None:
return image, target return image, target
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment