Unverified Commit a0867fac authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding RandomShortestSize transform (#5610)

parent bb79470a
from typing import List, Tuple, Dict, Optional
from typing import List, Tuple, Dict, Optional, Union
import torch
import torchvision
......@@ -401,3 +401,39 @@ class FixedSizeCrop(nn.Module):
img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])
return img, target
class RandomShortestSize(nn.Module):
def __init__(
self,
min_size: Union[List[int], Tuple[int], int],
max_size: int,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
):
super().__init__()
self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
self.max_size = max_size
self.interpolation = interpolation
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
_, orig_height, orig_width = F.get_dimensions(image)
min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
new_width = int(orig_width * r)
new_height = int(orig_height * r)
image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
if target is not None:
target["boxes"][:, 0::2] *= new_width / orig_width
target["boxes"][:, 1::2] *= new_height / orig_height
if "masks" in target:
target["masks"] = F.resize(
target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
)
return image, target
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment