"""Geometric transforms (v2): flips, crops, resizes, affine/perspective warps, padding.

NOTE(review): this module appears to be a chunk of torchvision's
``transforms.v2._geometry``; each class wraps a kernel from
``torchvision.transforms.v2.functional`` and (where available) mirrors a v1
transform via ``_v1_transform_cls`` for scripting compatibility.
"""

import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union

import PIL.Image
import torch
from torchvision import datapoints, transforms as _transforms
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from ._transform import _RandomApplyTransform
from ._utils import (
    _check_padding_arg,
    _check_padding_mode_arg,
    _check_sequence_input,
    _setup_angle,
    _setup_fill_arg,
    _setup_float_or_seq,
    _setup_size,
)
from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_spatial_size


class RandomHorizontalFlip(_RandomApplyTransform):
    """Horizontally flip the input; the random-apply probability is handled by the base class."""

    _v1_transform_cls = _transforms.RandomHorizontalFlip

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.horizontal_flip(inpt)


class RandomVerticalFlip(_RandomApplyTransform):
    """Vertically flip the input; the random-apply probability is handled by the base class."""

    _v1_transform_cls = _transforms.RandomVerticalFlip

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.vertical_flip(inpt)


class Resize(Transform):
    """Resize the input to ``size`` via ``F.resize``.

    ``size`` may be a single int (shorter edge) or a 1/2-element sequence;
    anything else raises ``ValueError``. ``max_size`` and ``antialias`` are
    forwarded unchanged to the kernel.
    """

    _v1_transform_cls = _transforms.Resize

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> None:
        super().__init__()
        # Normalize `size` to a list of one or two ints; reject everything else.
        if isinstance(size, int):
            size = [size]
        elif isinstance(size, (list, tuple)) and len(size) in {1, 2}:
            size = list(size)
        else:
            raise ValueError(
                f"size can either be an integer or a list or tuple of one or two integers, "
                f"but got {size} instead."
            )
        self.size = size
        self.interpolation = _check_interpolation(interpolation)
        self.max_size = max_size
        self.antialias = antialias

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.resize(
            inpt,
            self.size,
            interpolation=self.interpolation,
            max_size=self.max_size,
            antialias=self.antialias,
        )


class CenterCrop(Transform):
    """Crop a ``size``-sized region from the center of the input."""

    _v1_transform_cls = _transforms.CenterCrop

    def __init__(self, size: Union[int, Sequence[int]]):
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.center_crop(inpt, output_size=self.size)


class RandomResizedCrop(Transform):
    """Crop a random area/aspect-ratio region, then resize it to ``size``.

    Sampling: up to 10 attempts draw an area fraction from ``scale`` and an
    aspect ratio log-uniformly from ``ratio``; if none fits inside the image,
    falls back to a central crop clamped to the ratio range.
    """

    _v1_transform_cls = _transforms.RandomResizedCrop

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        scale: Tuple[float, float] = (0.08, 1.0),
        ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if not isinstance(scale, Sequence):
            raise TypeError("Scale should be a sequence")
        scale = cast(Tuple[float, float], scale)
        if not isinstance(ratio, Sequence):
            raise TypeError("Ratio should be a sequence")
        ratio = cast(Tuple[float, float], ratio)
        # Misordered bounds are only warned about, not rejected (matches v1 behavior).
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("Scale and ratio should be of kind (min, max)")

        self.scale = scale
        self.ratio = ratio
        self.interpolation = _check_interpolation(interpolation)
        self.antialias = antialias

        # Precomputed so each _get_params call can sample aspect ratios log-uniformly.
        self._log_ratio = torch.log(torch.tensor(self.ratio))

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = query_spatial_size(flat_inputs)
        area = height * width

        log_ratio = self._log_ratio
        # Rejection-sample a crop that fits inside the image; give up after 10 tries.
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(
                    log_ratio[0],  # type: ignore[arg-type]
                    log_ratio[1],  # type: ignore[arg-type]
                )
            ).item()
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                break
        else:
            # Fallback to central crop
            in_ratio = float(width) / float(height)
            if in_ratio < min(self.ratio):
                w = width
                h = int(round(w / min(self.ratio)))
            elif in_ratio > max(self.ratio):
                h = height
                w = int(round(h * max(self.ratio)))
            else:  # whole image
                w = width
                h = height
            i = (height - h) // 2
            j = (width - w) // 2

        return dict(top=i, left=j, height=h, width=w)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.resized_crop(
            inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
        )


# JIT-friendly union of image and video input types accepted by the multi-crop transforms.
ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]


class FiveCrop(Transform):
    """Return the four corner crops and the center crop of the input.

    Bounding boxes and masks are rejected because a 5-tuple output has no
    meaningful box/mask counterpart.

    Example:
        >>> class BatchMultiCrop(transforms.Transform):
        ...     def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], int]):
        ...         images_or_videos, labels = sample
        ...         batch_size = len(images_or_videos)
        ...         image_or_video = images_or_videos[0]
        ...         images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
        ...         labels = torch.full((batch_size,), label, device=images_or_videos.device)
        ...         return images_or_videos, labels
        ...
        >>> image = datapoints.Image(torch.rand(3, 256, 256))
        >>> label = 3
        >>> transform = transforms.Compose([transforms.FiveCrop(224), BatchMultiCrop()])
        >>> images, labels = transform(image, label)
        >>> images.shape
        torch.Size([5, 3, 224, 224])
        >>> labels
        tensor([3, 3, 3, 3, 3])
    """

    _v1_transform_cls = _transforms.FiveCrop

    # Only image-like inputs are transformed; other types pass through untouched.
    _transformed_types = (
        datapoints.Image,
        PIL.Image.Image,
        is_simple_tensor,
        datapoints.Video,
    )

    def __init__(self, size: Union[int, Sequence[int]]) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def _transform(
        self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any]
    ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
        return F.five_crop(inpt, self.size)

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
            raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")


class TenCrop(Transform):
    """Return the five crops of the input plus the five crops of its flipped version.

    See :class:`~torchvision.transforms.v2.FiveCrop` for an example.
    """

    _v1_transform_cls = _transforms.TenCrop

    # Only image-like inputs are transformed; other types pass through untouched.
    _transformed_types = (
        datapoints.Image,
        PIL.Image.Image,
        is_simple_tensor,
        datapoints.Video,
    )

    def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        # If True the second set of crops comes from a vertical (not horizontal) flip.
        self.vertical_flip = vertical_flip

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        if has_any(flat_inputs, datapoints.BoundingBox, datapoints.Mask):
            raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")

    def _transform(
        self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
    ) -> Tuple[
        ImageOrVideoTypeJIT,
        ImageOrVideoTypeJIT,
        ImageOrVideoTypeJIT,
        ImageOrVideoTypeJIT,
        ImageOrVideoTypeJIT,
        ImageOrVideoTypeJIT,
        ImageOrVideoTypeJIT,
        ImageOrVideoTypeJIT,
        ImageOrVideoTypeJIT,
        ImageOrVideoTypeJIT,
    ]:
        return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)


class Pad(Transform):
    """Pad the input on all requested sides with the given fill and padding mode."""

    _v1_transform_cls = _transforms.Pad

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        params = super()._extract_params_for_v1_transform()
        # The scripted v1 transform only supports scalar fills.
        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
            raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
        return params

    def __init__(
        self,
        padding: Union[int, Sequence[int]],
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()
        _check_padding_arg(padding)
        _check_padding_mode_arg(padding_mode)

        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
        if not isinstance(padding, int):
            padding = list(padding)
        self.padding = padding
        self.fill = fill
        # `_fill` maps an input type to the fill to use for that type.
        self._fill = _setup_fill_arg(fill)
        self.padding_mode = padding_mode

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = self._fill[type(inpt)]
        return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]


class RandomZoomOut(_RandomApplyTransform):
    """Place the input on a larger canvas (zoom-out) at a random position.

    The canvas side is ``side_range[0]..side_range[1]`` times the original
    side; the extra space is expressed as padding on each side.
    """

    def __init__(
        self,
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
        side_range: Sequence[float] = (1.0, 4.0),
        p: float = 0.5,
    ) -> None:
        super().__init__(p=p)

        self.fill = fill
        self._fill = _setup_fill_arg(fill)

        _check_sequence_input(side_range, "side_range", req_sizes=(2,))

        self.side_range = side_range
        # Must be an increasing range starting at >= 1 (a zoom-OUT cannot shrink the canvas).
        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
            raise ValueError(f"Invalid canvas side range provided {side_range}.")

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        orig_h, orig_w = query_spatial_size(flat_inputs)

        # Random zoom factor, then random placement of the original inside the canvas.
        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
        canvas_width = int(orig_w * r)
        canvas_height = int(orig_h * r)

        r = torch.rand(2)
        left = int((canvas_width - orig_w) * r[0])
        top = int((canvas_height - orig_h) * r[1])
        right = canvas_width - (left + orig_w)
        bottom = canvas_height - (top + orig_h)
        padding = [left, top, right, bottom]

        return dict(padding=padding)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = self._fill[type(inpt)]
        return F.pad(inpt, **params, fill=fill)


class RandomRotation(Transform):
    """Rotate the input by an angle drawn uniformly from ``degrees``."""

    _v1_transform_cls = _transforms.RandomRotation

    def __init__(
        self,
        degrees: Union[numbers.Number, Sequence],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        expand: bool = False,
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
        center: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
        self.interpolation = _check_interpolation(interpolation)
        self.expand = expand

        self.fill = fill
        self._fill = _setup_fill_arg(fill)

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
        return dict(angle=angle)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = self._fill[type(inpt)]
        return F.rotate(
            inpt,
            **params,
            interpolation=self.interpolation,
            expand=self.expand,
            center=self.center,
            fill=fill,
        )


class RandomAffine(Transform):
    """Apply a random affine transform (rotation, translation, scale, shear)."""

    _v1_transform_cls = _transforms.RandomAffine

    def __init__(
        self,
        degrees: Union[numbers.Number, Sequence],
        translate: Optional[Sequence[float]] = None,
        scale: Optional[Sequence[float]] = None,
        shear: Optional[Union[int, float, Sequence[float]]] = None,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
        center: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
        if translate is not None:
            _check_sequence_input(translate, "translate", req_sizes=(2,))
            # Translations are fractions of the image size, so must lie in [0, 1].
            for t in translate:
                if not (0.0 <= t <= 1.0):
                    raise ValueError("translation values should be between 0 and 1")
        self.translate = translate
        if scale is not None:
            _check_sequence_input(scale, "scale", req_sizes=(2,))
            for s in scale:
                if s <= 0:
                    raise ValueError("scale values should be positive")
        self.scale = scale

        if shear is not None:
            # 2 values -> x-shear range only; 4 values -> x- and y-shear ranges.
            self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4))
        else:
            self.shear = shear

        self.interpolation = _check_interpolation(interpolation)
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

        if center is not None:
            _check_sequence_input(center, "center", req_sizes=(2,))

        self.center = center

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = query_spatial_size(flat_inputs)

        angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
        if self.translate is not None:
            # Translate fractions scale with the image dimensions.
            max_dx = float(self.translate[0] * width)
            max_dy = float(self.translate[1] * height)
            tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
            ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
            translate = (tx, ty)
        else:
            translate = (0, 0)

        if self.scale is not None:
            scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
        else:
            scale = 1.0

        shear_x = shear_y = 0.0
        if self.shear is not None:
            shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
            if len(self.shear) == 4:
                shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()

        shear = (shear_x, shear_y)
        return dict(angle=angle, translate=translate, scale=scale, shear=shear)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = self._fill[type(inpt)]
        return F.affine(
            inpt,
            **params,
            interpolation=self.interpolation,
            fill=fill,
            center=self.center,
        )


class RandomCrop(Transform):
    """Crop a random ``size``-sized region, optionally padding first.

    ``padding`` (if given) is applied up front; ``pad_if_needed`` then adds
    whatever extra padding is required so the crop always fits.
    """

    _v1_transform_cls = _transforms.RandomCrop

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        params = super()._extract_params_for_v1_transform()

        # The scripted v1 transform only supports scalar fills.
        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
            raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")

        padding = self.padding
        if padding is not None:
            # v1 expects [left, top, right, bottom]; internally we store the parsed
            # [left, right, top, bottom] order.
            pad_left, pad_right, pad_top, pad_bottom = padding
            padding = [pad_left, pad_top, pad_right, pad_bottom]
        params["padding"] = padding

        return params

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        padding: Optional[Union[int, Sequence[int]]] = None,
        pad_if_needed: bool = False,
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()

        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

        if pad_if_needed or padding is not None:
            if padding is not None:
                _check_padding_arg(padding)
            _check_padding_mode_arg(padding_mode)

        # Stored pre-parsed as [left, right, top, bottom] (or None when unused).
        self.padding = F._geometry._parse_pad_padding(padding) if padding else None  # type: ignore[arg-type]
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self._fill = _setup_fill_arg(fill)
        self.padding_mode = padding_mode

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        padded_height, padded_width = query_spatial_size(flat_inputs)

        if self.padding is not None:
            pad_left, pad_right, pad_top, pad_bottom = self.padding
            padded_height += pad_top + pad_bottom
            padded_width += pad_left + pad_right
        else:
            pad_left = pad_right = pad_top = pad_bottom = 0

        cropped_height, cropped_width = self.size

        if self.pad_if_needed:
            # Pad symmetrically by the missing amount on BOTH sides of the short dimension.
            if padded_height < cropped_height:
                diff = cropped_height - padded_height

                pad_top += diff
                pad_bottom += diff
                padded_height += 2 * diff

            if padded_width < cropped_width:
                diff = cropped_width - padded_width

                pad_left += diff
                pad_right += diff
                padded_width += 2 * diff

        if padded_height < cropped_height or padded_width < cropped_width:
            raise ValueError(
                f"Required crop size {(cropped_height, cropped_width)} is larger than "
                f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}."
            )

        # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad`
        padding = [pad_left, pad_top, pad_right, pad_bottom]
        needs_pad = any(padding)

        needs_vert_crop, top = (
            (True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
            if padded_height > cropped_height
            else (False, 0)
        )
        needs_horz_crop, left = (
            (True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
            if padded_width > cropped_width
            else (False, 0)
        )

        return dict(
            needs_crop=needs_vert_crop or needs_horz_crop,
            top=top,
            left=left,
            height=cropped_height,
            width=cropped_width,
            needs_pad=needs_pad,
            padding=padding,
        )

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if params["needs_pad"]:
            fill = self._fill[type(inpt)]
            inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

        if params["needs_crop"]:
            inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])

        return inpt


class RandomPerspective(_RandomApplyTransform):
    """Apply a random perspective warp, controlled by ``distortion_scale`` in [0, 1]."""

    _v1_transform_cls = _transforms.RandomPerspective

    def __init__(
        self,
        distortion_scale: float = 0.5,
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        p: float = 0.5,
    ) -> None:
        super().__init__(p=p)

        if not (0 <= distortion_scale <= 1):
            raise ValueError("Argument distortion_scale value should be between 0 and 1")

        self.distortion_scale = distortion_scale
        self.interpolation = _check_interpolation(interpolation)
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = query_spatial_size(flat_inputs)

        distortion_scale = self.distortion_scale

        # Each corner is jittered independently within a band whose width is
        # proportional to distortion_scale.
        half_height = height // 2
        half_width = width // 2
        bound_height = int(distortion_scale * half_height) + 1
        bound_width = int(distortion_scale * half_width) + 1
        topleft = [
            int(torch.randint(0, bound_width, size=(1,))),
            int(torch.randint(0, bound_height, size=(1,))),
        ]
        topright = [
            int(torch.randint(width - bound_width, width, size=(1,))),
            int(torch.randint(0, bound_height, size=(1,))),
        ]
        botright = [
            int(torch.randint(width - bound_width, width, size=(1,))),
            int(torch.randint(height - bound_height, height, size=(1,))),
        ]
        botleft = [
            int(torch.randint(0, bound_width, size=(1,))),
            int(torch.randint(height - bound_height, height, size=(1,))),
        ]
        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
        endpoints = [topleft, topright, botright, botleft]
        perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
        return dict(coefficients=perspective_coeffs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = self._fill[type(inpt)]
        return F.perspective(
            inpt,
            None,
            None,
            fill=fill,
            interpolation=self.interpolation,
            **params,
        )


class ElasticTransform(Transform):
    """Apply a random elastic deformation via a smoothed displacement field.

    ``alpha`` scales the displacement magnitude; ``sigma`` is the Gaussian
    smoothing std-dev (per axis).
    """

    _v1_transform_cls = _transforms.ElasticTransform

    def __init__(
        self,
        alpha: Union[float, Sequence[float]] = 50.0,
        sigma: Union[float, Sequence[float]] = 5.0,
        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
        self.sigma = _setup_float_or_seq(sigma, "sigma", 2)

        self.interpolation = _check_interpolation(interpolation)
        self.fill = fill
        self._fill = _setup_fill_arg(fill)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        size = list(query_spatial_size(flat_inputs))

        # Random displacement in [-1, 1), optionally smoothed with a Gaussian blur.
        dx = torch.rand([1, 1] + size) * 2 - 1
        if self.sigma[0] > 0.0:
            kx = int(8 * self.sigma[0] + 1)
            # if kernel size is even we have to make it odd
            if kx % 2 == 0:
                kx += 1
            dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma))
        dx = dx * self.alpha[0] / size[0]

        dy = torch.rand([1, 1] + size) * 2 - 1
        if self.sigma[1] > 0.0:
            ky = int(8 * self.sigma[1] + 1)
            # if kernel size is even we have to make it odd
            if ky % 2 == 0:
                ky += 1
            dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma))
        dy = dy * self.alpha[1] / size[1]
        displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2
        return dict(displacement=displacement)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = self._fill[type(inpt)]
        return F.elastic(
            inpt,
            **params,
            fill=fill,
            interpolation=self.interpolation,
        )


class RandomIoUCrop(Transform):
    """SSD-style random IoU crop: sample crops until one overlaps the boxes enough.

    Requires bounding boxes plus an image-like input in the sample.
    """

    def __init__(
        self,
        min_scale: float = 0.3,
        max_scale: float = 1.0,
        min_aspect_ratio: float = 0.5,
        max_aspect_ratio: float = 2.0,
        sampler_options: Optional[List[float]] = None,
        trials: int = 40,
    ):
        super().__init__()
        # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        if sampler_options is None:
            sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
        self.options = sampler_options
        self.trials = trials

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        if not (
            has_all(flat_inputs, datapoints.BoundingBox)
            and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor)
        ):
            raise TypeError(
                f"{type(self).__name__}() requires input sample to contain tensor or PIL images "
                "and bounding boxes. Sample can also contain masks."
            )

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        orig_h, orig_w = query_spatial_size(flat_inputs)
        bboxes = query_bounding_box(flat_inputs)

        # NOTE(review): loops until a valid crop (or the >=1.0 "leave as-is" option)
        # is sampled — there is no upper bound on the number of outer iterations.
        while True:
            # sample an option
            idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
            min_jaccard_overlap = self.options[idx]
            if min_jaccard_overlap >= 1.0:  # a value larger than 1 encodes the leave as-is option
                return dict()

            for _ in range(self.trials):
                # check the aspect ratio limitations
                r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
                new_w = int(orig_w * r[0])
                new_h = int(orig_h * r[1])
                aspect_ratio = new_w / new_h
                if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
                    continue

                # check for 0 area crops
                r = torch.rand(2)
                left = int((orig_w - new_w) * r[0])
                top = int((orig_h - new_h) * r[1])
                right = left + new_w
                bottom = top + new_h
                if left == right or top == bottom:
                    continue

                # check for any valid boxes with centers within the crop area
                xyxy_bboxes = F.convert_format_bounding_box(
                    bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
                )
                cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
                cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
                is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
                if not is_within_crop_area.any():
                    continue

                # check at least 1 box with jaccard limitations
                xyxy_bboxes = xyxy_bboxes[is_within_crop_area]
                ious = box_iou(
                    xyxy_bboxes,
                    torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device),
                )
                if ious.max() < min_jaccard_overlap:
                    continue

                return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Empty params means the "leave as-is" option was sampled.
        if len(params) < 1:
            return inpt

        output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])

        if isinstance(output, datapoints.BoundingBox):
            # We "mark" the invalid boxes as degenerate, and they can be
            # removed by a later call to SanitizeBoundingBoxes()
            output[~params["is_within_crop_area"]] = 0

        return output


class ScaleJitter(Transform):
    """Resize the input by a random scale factor relative to ``target_size``."""

    def __init__(
        self,
        target_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (0.1, 2.0),
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ):
        super().__init__()
        self.target_size = target_size
        self.scale_range = scale_range
        self.interpolation = _check_interpolation(interpolation)
        self.antialias = antialias

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        orig_height, orig_width = query_spatial_size(flat_inputs)

        scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
        # Fit the scaled image inside target_size while preserving aspect ratio.
        r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
        new_width = int(orig_width * r)
        new_height = int(orig_height * r)

        return dict(size=(new_height, new_width))

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)


class RandomShortestSize(Transform):
    """Resize so the shorter edge matches a randomly chosen ``min_size`` entry,
    optionally capping the longer edge at ``max_size``."""

    def __init__(
        self,
        min_size: Union[List[int], Tuple[int], int],
        max_size: Optional[int] = None,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ):
        super().__init__()
        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
        self.max_size = max_size
        self.interpolation = _check_interpolation(interpolation)
        self.antialias = antialias

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        orig_height, orig_width = query_spatial_size(flat_inputs)

        min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
        r = min_size / min(orig_height, orig_width)
        if self.max_size is not None:
            # Shrink further if the longer edge would exceed max_size.
            r = min(r, self.max_size / max(orig_height, orig_width))

        new_width = int(orig_width * r)
        new_height = int(orig_height * r)

        return dict(size=(new_height, new_width))

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)


class RandomResize(Transform):
    """Resize to a square-ish size sampled uniformly from [min_size, max_size)."""

    def __init__(
        self,
        min_size: int,
        max_size: int,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
    ) -> None:
        super().__init__()
        self.min_size = min_size
        self.max_size = max_size
        self.interpolation = _check_interpolation(interpolation)
        self.antialias = antialias

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # NOTE: upper bound is exclusive (torch.randint convention).
        size = int(torch.randint(self.min_size, self.max_size, ()))
        return dict(size=[size])

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias)