Unverified Commit 300a9092 authored by Edward Z. Yang's avatar Edward Z. Yang Committed by GitHub
Browse files

Add non-TS'able _resize_image_and_masks variant with less tensor ops (#7592)


Signed-off-by: default avatarEdward Z. Yang <ezyang@meta.com>
parent d2f7486c
......@@ -24,8 +24,8 @@ def _fake_cast_onnx(v: Tensor) -> float:
def _resize_image_and_masks(
image: Tensor,
self_min_size: float,
self_max_size: float,
self_min_size: int,
self_max_size: int,
target: Optional[Dict[str, Tensor]] = None,
fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
......@@ -40,14 +40,24 @@ def _resize_image_and_masks(
if fixed_size is not None:
size = [fixed_size[1], fixed_size[0]]
else:
if torch.jit.is_scripting() or torchvision._is_tracing():
min_size = torch.min(im_shape).to(dtype=torch.float32)
max_size = torch.max(im_shape).to(dtype=torch.float32)
scale = torch.min(self_min_size / min_size, self_max_size / max_size)
self_min_size_f = float(self_min_size)
self_max_size_f = float(self_max_size)
scale = torch.min(self_min_size_f / min_size, self_max_size_f / max_size)
if torchvision._is_tracing():
scale_factor = _fake_cast_onnx(scale)
else:
scale_factor = scale.item()
else:
# Do it the normal way
min_size = min(im_shape)
max_size = max(im_shape)
scale_factor = min(self_min_size / min_size, self_max_size / max_size)
recompute_scale_factor = True
image = torch.nn.functional.interpolate(
......@@ -159,8 +169,7 @@ class GeneralizedRCNNTransform(nn.Module):
def torch_choice(self, k: List[int]) -> int:
"""
Implements `random.choice` via torch ops, so it can be compiled with
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
is fixed.
TorchScript and we use PyTorch's RNG (not native RNG)
"""
index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
return k[index]
......@@ -174,11 +183,10 @@ class GeneralizedRCNNTransform(nn.Module):
if self.training:
if self._skip_resize:
return image, target
size = float(self.torch_choice(self.min_size))
size = self.torch_choice(self.min_size)
else:
# FIXME assume for now that testing uses the largest scale
size = float(self.min_size[-1])
image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size)
size = self.min_size[-1]
image, target = _resize_image_and_masks(image, size, self.max_size, target, self.fixed_size)
if target is None:
return image, target
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment