"launcher/vscode:/vscode.git/clone" did not exist on "e6d3eb5d5d257fb20caf1f86fb3cd2ef530fb555"
Unverified commit 312c3d32, authored by Philip Meier, committed by GitHub

remove spatial_size (#7734)

parent bdf16222
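The diff below renames the size-related API throughout transforms v2: `get_spatial_size*` becomes `get_size*`, `query_spatial_size` becomes `query_size`, and the `spatial_size` argument/attribute of bounding boxes becomes `canvas_size`. As a rough, hypothetical before/after sketch of downstream code affected by the rename (it assumes the `BoundingBoxes` constructor follows the same `spatial_size` → `canvas_size` change, which is outside the hunks shown here):

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

image = datapoints.Image(torch.rand(3, 480, 640))

# Before this commit: height, width = F.get_spatial_size(image)
height, width = F.get_size(image)  # still returns [h, w]

boxes = datapoints.BoundingBoxes(
    torch.tensor([[10.0, 20.0, 700.0, 500.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(480, 640),  # previously spatial_size=(480, 640); assumed to follow the rename
)

# For plain tensor inputs the metadata is passed explicitly; the keyword is now canvas_size.
clamped = F.clamp_bounding_boxes(
    boxes.as_subclass(torch.Tensor),
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(480, 640),
)
```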
@@ -9,7 +9,7 @@ from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms import _functional_tensor as _FT
 from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.functional._meta import get_spatial_size
+from torchvision.transforms.v2.functional._meta import get_size

 from ._utils import _setup_fill_arg
 from .utils import check_type, is_simple_tensor
@@ -312,7 +312,7 @@ class AutoAugment(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)

-        height, width = get_spatial_size(image_or_video)
+        height, width = get_size(image_or_video)

         policy = self._policies[int(torch.randint(len(self._policies), ()))]
@@ -403,7 +403,7 @@ class RandAugment(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)

-        height, width = get_spatial_size(image_or_video)
+        height, width = get_size(image_or_video)

         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -474,7 +474,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)

-        height, width = get_spatial_size(image_or_video)
+        height, width = get_size(image_or_video)

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -568,7 +568,7 @@ class AugMix(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)

-        height, width = get_spatial_size(orig_image_or_video)
+        height, width = get_size(orig_image_or_video)

         if isinstance(orig_image_or_video, torch.Tensor):
             image_or_video = orig_image_or_video
...
@@ -22,7 +22,7 @@ from ._utils import (
     _setup_float_or_seq,
     _setup_size,
 )
-from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_spatial_size
+from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size

 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -267,7 +267,7 @@ class RandomResizedCrop(Transform):
         self._log_ratio = torch.log(torch.tensor(self.ratio))

     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        height, width = query_spatial_size(flat_inputs)
+        height, width = query_size(flat_inputs)
         area = height * width

         log_ratio = self._log_ratio
@@ -558,7 +558,7 @@ class RandomZoomOut(_RandomApplyTransform):
             raise ValueError(f"Invalid canvas side range provided {side_range}.")

     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        orig_h, orig_w = query_spatial_size(flat_inputs)
+        orig_h, orig_w = query_size(flat_inputs)

         r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
         canvas_width = int(orig_w * r)
@@ -735,7 +735,7 @@ class RandomAffine(Transform):
         self.center = center

     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        height, width = query_spatial_size(flat_inputs)
+        height, width = query_size(flat_inputs)

         angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
         if self.translate is not None:
@@ -859,7 +859,7 @@ class RandomCrop(Transform):
         self.padding_mode = padding_mode

     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        padded_height, padded_width = query_spatial_size(flat_inputs)
+        padded_height, padded_width = query_size(flat_inputs)

         if self.padding is not None:
             pad_left, pad_right, pad_top, pad_bottom = self.padding
@@ -972,7 +972,7 @@ class RandomPerspective(_RandomApplyTransform):
         self._fill = _setup_fill_arg(fill)

     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        height, width = query_spatial_size(flat_inputs)
+        height, width = query_size(flat_inputs)

         distortion_scale = self.distortion_scale
@@ -1072,7 +1072,7 @@ class ElasticTransform(Transform):
         self._fill = _setup_fill_arg(fill)

     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        size = list(query_spatial_size(flat_inputs))
+        size = list(query_size(flat_inputs))

         dx = torch.rand([1, 1] + size) * 2 - 1
         if self.sigma[0] > 0.0:
@@ -1164,7 +1164,7 @@ class RandomIoUCrop(Transform):
         )

     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        orig_h, orig_w = query_spatial_size(flat_inputs)
+        orig_h, orig_w = query_size(flat_inputs)
         bboxes = query_bounding_boxes(flat_inputs)

         while True:
@@ -1282,7 +1282,7 @@ class ScaleJitter(Transform):
         self.antialias = antialias

     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        orig_height, orig_width = query_spatial_size(flat_inputs)
+        orig_height, orig_width = query_size(flat_inputs)

         scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
         r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
@@ -1347,7 +1347,7 @@ class RandomShortestSize(Transform):
         self.antialias = antialias

     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        orig_height, orig_width = query_spatial_size(flat_inputs)
+        orig_height, orig_width = query_size(flat_inputs)

         min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
         r = min_size / min(orig_height, orig_width)
...
@@ -30,7 +30,7 @@ class ConvertBoundingBoxFormat(Transform):
 class ClampBoundingBoxes(Transform):
     """[BETA] Clamp bounding boxes to their corresponding image dimensions.

-    The clamping is done according to the bounding boxes' ``spatial_size`` meta-data.
+    The clamping is done according to the bounding boxes' ``canvas_size`` meta-data.

     .. v2betastatus:: ClampBoundingBoxes transform
...
@@ -408,7 +408,7 @@ class SanitizeBoundingBoxes(Transform):
         valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
         # TODO: Do we really need to check for out of bounds here? All
         # transforms should be clamping anyway, so this should never happen?
-        image_h, image_w = boxes.spatial_size
+        image_h, image_w = boxes.canvas_size
         valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
         valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
...
@@ -15,12 +15,12 @@ from ._meta import (
     get_num_channels_image_pil,
     get_num_channels_video,
     get_num_channels,
-    get_spatial_size_bounding_boxes,
-    get_spatial_size_image_tensor,
-    get_spatial_size_image_pil,
-    get_spatial_size_mask,
-    get_spatial_size_video,
-    get_spatial_size,
+    get_size_bounding_boxes,
+    get_size_image_tensor,
+    get_size_image_pil,
+    get_size_mask,
+    get_size_video,
+    get_size,
 )  # usort: skip

 from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video
...
@@ -19,6 +19,6 @@ def to_tensor(inpt: Any) -> torch.Tensor:
 def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
-        "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
+        "Instead, please use `get_size(...)` which returns `[h, w]` instead of `[w, h]`."
     )
     return _F.get_image_size(inpt)
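Not part of the diff: a small sketch of the ordering difference the warning refers to, assuming a 3×480×640 image (height 480, width 640) and that the deprecated helper is still importable from the v2 functional namespace at this commit:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

img = datapoints.Image(torch.rand(3, 480, 640))

F.get_image_size(img)  # deprecated: returns [640, 480], i.e. [w, h]
F.get_size(img)        # replacement: returns [480, 640], i.e. [h, w]
```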
@@ -23,7 +23,7 @@ from torchvision.transforms.functional import (
 from torchvision.utils import _log_api_usage_once

-from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_spatial_size_image_pil
+from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
 from ._utils import is_simple_tensor
@@ -52,18 +52,18 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
 def horizontal_flip_bounding_boxes(
-    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
+    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
 ) -> torch.Tensor:
     shape = bounding_boxes.shape

     bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

     if format == datapoints.BoundingBoxFormat.XYXY:
-        bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(spatial_size[1]).neg_()
+        bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_()
     elif format == datapoints.BoundingBoxFormat.XYWH:
-        bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(spatial_size[1]).neg_()
+        bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_()
     else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
-        bounding_boxes[:, 0].sub_(spatial_size[1]).neg_()
+        bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()

     return bounding_boxes.reshape(shape)
@@ -102,18 +102,18 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
 def vertical_flip_bounding_boxes(
-    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
+    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
 ) -> torch.Tensor:
     shape = bounding_boxes.shape

     bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

     if format == datapoints.BoundingBoxFormat.XYXY:
-        bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(spatial_size[0]).neg_()
+        bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_()
     elif format == datapoints.BoundingBoxFormat.XYWH:
-        bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(spatial_size[0]).neg_()
+        bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_()
     else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
-        bounding_boxes[:, 1].sub_(spatial_size[0]).neg_()
+        bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()

     return bounding_boxes.reshape(shape)
@@ -146,7 +146,7 @@ vflip = vertical_flip

 def _compute_resized_output_size(
-    spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
+    canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
 ) -> List[int]:
     if isinstance(size, int):
         size = [size]
@@ -155,7 +155,7 @@ def _compute_resized_output_size(
             "max_size should only be passed if size specifies the length of the smaller edge, "
             "i.e. size should be an int or a sequence of length 1 in torchscript mode."
         )
-    return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)
+    return __compute_resized_output_size(canvas_size, size=size, max_size=max_size)

 def resize_image_tensor(
@@ -275,13 +275,13 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
 def resize_bounding_boxes(
-    bounding_boxes: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
+    bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    old_height, old_width = spatial_size
-    new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)
+    old_height, old_width = canvas_size
+    new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)

     if (new_height, new_width) == (old_height, old_width):
-        return bounding_boxes, spatial_size
+        return bounding_boxes, canvas_size

     w_ratio = new_width / old_width
     h_ratio = new_height / old_height
@@ -643,7 +643,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        height, width = get_spatial_size_image_pil(image)
+        height, width = get_size_image_pil(image)
         center = [width * 0.5, height * 0.5]

     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
@@ -653,7 +653,7 @@ def affine_image_pil(
 def _affine_bounding_boxes_with_expand(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     angle: Union[int, float],
     translate: List[float],
     scale: float,
@@ -662,7 +662,7 @@ def _affine_bounding_boxes_with_expand(
     expand: bool = False,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     if bounding_boxes.numel() == 0:
-        return bounding_boxes, spatial_size
+        return bounding_boxes, canvas_size

     original_shape = bounding_boxes.shape
     original_dtype = bounding_boxes.dtype
@@ -680,7 +680,7 @@ def _affine_bounding_boxes_with_expand(
     )
     if center is None:
-        height, width = spatial_size
+        height, width = canvas_size
         center = [width * 0.5, height * 0.5]

     affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
@@ -710,7 +710,7 @@ def _affine_bounding_boxes_with_expand(
     if expand:
         # Compute minimum point for transformed image frame:
         # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
-        height, width = spatial_size
+        height, width = canvas_size
         points = torch.tensor(
             [
                 [0.0, 0.0, 1.0],
@@ -728,21 +728,21 @@ def _affine_bounding_boxes_with_expand(
         # Estimate meta-data for image with inverted=True and with center=[0,0]
         affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
         new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
-        spatial_size = (new_height, new_width)
+        canvas_size = (new_height, new_width)

-    out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
+    out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
     out_bboxes = convert_format_bounding_boxes(
         out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(original_shape)

     out_bboxes = out_bboxes.to(original_dtype)
-    return out_bboxes, spatial_size
+    return out_bboxes, canvas_size

 def affine_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     angle: Union[int, float],
     translate: List[float],
     scale: float,
@@ -752,7 +752,7 @@ def affine_bounding_boxes(
     out_box, _ = _affine_bounding_boxes_with_expand(
         bounding_boxes,
         format=format,
-        spatial_size=spatial_size,
+        canvas_size=canvas_size,
         angle=angle,
         translate=translate,
         scale=scale,
@@ -930,7 +930,7 @@ def rotate_image_pil(
 def rotate_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     angle: float,
     expand: bool = False,
     center: Optional[List[float]] = None,
@@ -941,7 +941,7 @@ def rotate_bounding_boxes(
     return _affine_bounding_boxes_with_expand(
         bounding_boxes,
         format=format,
-        spatial_size=spatial_size,
+        canvas_size=canvas_size,
         angle=-angle,
         translate=[0.0, 0.0],
         scale=1.0,
@@ -1168,7 +1168,7 @@ def pad_mask(
 def pad_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     padding: List[int],
     padding_mode: str = "constant",
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
@@ -1184,12 +1184,12 @@ def pad_bounding_boxes(
     pad = [left, top, 0, 0]
     bounding_boxes = bounding_boxes + torch.tensor(pad, dtype=bounding_boxes.dtype, device=bounding_boxes.device)

-    height, width = spatial_size
+    height, width = canvas_size
     height += top + bottom
     width += left + right
-    spatial_size = (height, width)
+    canvas_size = (height, width)

-    return clamp_bounding_boxes(bounding_boxes, format=format, spatial_size=spatial_size), spatial_size
+    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size

 def pad_video(
@@ -1261,9 +1261,9 @@ def crop_bounding_boxes(
     sub = [left, top, 0, 0]
     bounding_boxes = bounding_boxes - torch.tensor(sub, dtype=bounding_boxes.dtype, device=bounding_boxes.device)

-    spatial_size = (height, width)
+    canvas_size = (height, width)

-    return clamp_bounding_boxes(bounding_boxes, format=format, spatial_size=spatial_size), spatial_size
+    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size

 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
@@ -1412,7 +1412,7 @@ def perspective_image_pil(
 def perspective_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     coefficients: Optional[List[float]] = None,
@@ -1493,7 +1493,7 @@ def perspective_bounding_boxes(
     out_bboxes = clamp_bounding_boxes(
         torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
         format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
+        canvas_size=canvas_size,
     )

     # out_bboxes should be of shape [N boxes, 4]
@@ -1651,7 +1651,7 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: to
 def elastic_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     displacement: torch.Tensor,
 ) -> torch.Tensor:
     if bounding_boxes.numel() == 0:
@@ -1670,7 +1670,7 @@ def elastic_bounding_boxes(
         convert_format_bounding_boxes(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

-    id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
+    id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
     inv_grid = id_grid.sub_(displacement)
@@ -1683,7 +1683,7 @@ def elastic_bounding_boxes(
     index_x, index_y = index_xy[:, 0], index_xy[:, 1]

     # Transform points:
-    t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
+    t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
     transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)

     transformed_points = transformed_points.reshape(-1, 4, 2)
@@ -1691,7 +1691,7 @@ def elastic_bounding_boxes(
     out_bboxes = clamp_bounding_boxes(
         torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
         format=datapoints.BoundingBoxFormat.XYXY,
-        spatial_size=spatial_size,
+        canvas_size=canvas_size,
     )

     return convert_format_bounding_boxes(
@@ -1804,13 +1804,13 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor
 @torch.jit.unused
 def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_height, image_width = get_spatial_size_image_pil(image)
+    image_height, image_width = get_size_image_pil(image)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         image = pad_image_pil(image, padding_ltrb, fill=0)
-        image_height, image_width = get_spatial_size_image_pil(image)
+        image_height, image_width = get_size_image_pil(image)
         if crop_width == image_width and crop_height == image_height:
             return image
@@ -1821,11 +1821,11 @@ def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL
 def center_crop_bounding_boxes(
     bounding_boxes: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
-    spatial_size: Tuple[int, int],
+    canvas_size: Tuple[int, int],
     output_size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size)
+    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
     return crop_bounding_boxes(
         bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
     )
@@ -1905,7 +1905,7 @@ def resized_crop_bounding_boxes(
     size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     bounding_boxes, _ = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
-    return resize_bounding_boxes(bounding_boxes, spatial_size=(height, width), size=size)
+    return resize_bounding_boxes(bounding_boxes, canvas_size=(height, width), size=size)

 def resized_crop_mask(
@@ -2000,7 +2000,7 @@ def five_crop_image_pil(
     image: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    image_height, image_width = get_spatial_size_image_pil(image)
+    image_height, image_width = get_size_image_pil(image)

     if crop_width > image_width or crop_height > image_height:
         raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
...
@@ -26,19 +26,25 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
 get_dimensions_image_pil = _FP.get_dimensions

+def get_dimensions_video(video: torch.Tensor) -> List[int]:
+    return get_dimensions_image_tensor(video)

 def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
     if not torch.jit.is_scripting():
         _log_api_usage_once(get_dimensions)

     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_dimensions_image_tensor(inpt)
-    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
-        channels = inpt.num_channels
-        height, width = inpt.spatial_size
-        return [channels, height, width]
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_dimensions_image_pil(inpt)
-    else:
+
+    for typ, get_size_fn in {
+        datapoints.Image: get_dimensions_image_tensor,
+        datapoints.Video: get_dimensions_video,
+        PIL.Image.Image: get_dimensions_image_pil,
+    }.items():
+        if isinstance(inpt, typ):
+            return get_size_fn(inpt)
+
     raise TypeError(
         f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
         f"but got {type(inpt)} instead."
@@ -69,11 +75,15 @@ def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoType
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_num_channels_image_tensor(inpt)
-    elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
-        return inpt.num_channels
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_num_channels_image_pil(inpt)
-    else:
+
+    for typ, get_size_fn in {
+        datapoints.Image: get_num_channels_image_tensor,
+        datapoints.Video: get_num_channels_video,
+        PIL.Image.Image: get_num_channels_image_pil,
+    }.items():
+        if isinstance(inpt, typ):
+            return get_size_fn(inpt)
+
     raise TypeError(
         f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
         f"but got {type(inpt)} instead."
@@ -85,7 +95,7 @@ def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoType
 get_image_num_channels = get_num_channels

-def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
+def get_size_image_tensor(image: torch.Tensor) -> List[int]:
     hw = list(image.shape[-2:])
     ndims = len(hw)
     if ndims == 2:
@@ -95,35 +105,44 @@ def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
 @torch.jit.unused
-def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
+def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
     width, height = _FP.get_image_size(image)
     return [height, width]

-def get_spatial_size_video(video: torch.Tensor) -> List[int]:
-    return get_spatial_size_image_tensor(video)
+def get_size_video(video: torch.Tensor) -> List[int]:
+    return get_size_image_tensor(video)

-def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:
-    return get_spatial_size_image_tensor(mask)
+def get_size_mask(mask: torch.Tensor) -> List[int]:
+    return get_size_image_tensor(mask)

 @torch.jit.unused
-def get_spatial_size_bounding_boxes(bounding_boxes: datapoints.BoundingBoxes) -> List[int]:
-    return list(bounding_boxes.spatial_size)
+def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
+    return list(bounding_box.canvas_size)

-def get_spatial_size(inpt: datapoints._InputTypeJIT) -> List[int]:
+def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(get_spatial_size)
+        _log_api_usage_once(get_size)

     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return get_spatial_size_image_tensor(inpt)
-    elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBoxes, datapoints.Mask)):
-        return list(inpt.spatial_size)
-    elif isinstance(inpt, PIL.Image.Image):
-        return get_spatial_size_image_pil(inpt)
-    else:
+        return get_size_image_tensor(inpt)
+
+    # TODO: This is just the poor mans version of a dispatcher. This will be properly addressed with
+    # https://github.com/pytorch/vision/pull/7747 when we can register the kernels above without the need to have
+    # a method on the datapoint class
+    for typ, get_size_fn in {
+        datapoints.Image: get_size_image_tensor,
+        datapoints.BoundingBoxes: get_size_bounding_boxes,
+        datapoints.Mask: get_size_mask,
+        datapoints.Video: get_size_video,
+        PIL.Image.Image: get_size_image_pil,
+    }.items():
+        if isinstance(inpt, typ):
+            return get_size_fn(inpt)
+
     raise TypeError(
         f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
         f"but got {type(inpt)} instead."
@@ -141,7 +160,7 @@ def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return get_num_frames_video(inpt)
     elif isinstance(inpt, datapoints.Video):
-        return inpt.num_frames
+        return get_num_frames_video(inpt)
     else:
         raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
@@ -240,7 +259,7 @@ def convert_format_bounding_boxes(
 def _clamp_bounding_boxes(
-    bounding_boxes: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
+    bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: Tuple[int, int]
 ) -> torch.Tensor:
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     # BoundingBoxFormat instead of converting back and forth
@@ -249,8 +268,8 @@ def _clamp_bounding_boxes(
     xyxy_boxes = convert_format_bounding_boxes(
         bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
     )
-    xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
-    xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
+    xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
+    xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])
     out_boxes = convert_format_bounding_boxes(
         xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
     )
@@ -260,21 +279,20 @@ def _clamp_bounding_boxes(
 def clamp_bounding_boxes(
     inpt: datapoints._InputTypeJIT,
     format: Optional[BoundingBoxFormat] = None,
-    spatial_size: Optional[Tuple[int, int]] = None,
+    canvas_size: Optional[Tuple[int, int]] = None,
 ) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
         _log_api_usage_once(clamp_bounding_boxes)

     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        if format is None or spatial_size is None:
-            raise ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.")
-        return _clamp_bounding_boxes(inpt, format=format, spatial_size=spatial_size)
+        if format is None or canvas_size is None:
+            raise ValueError("For simple tensor inputs, `format` and `canvas_size` has to be passed.")
+        return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
     elif isinstance(inpt, datapoints.BoundingBoxes):
-        if format is not None or spatial_size is not None:
-            raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
-        output = _clamp_bounding_boxes(
-            inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
-        )
+        if format is not None or canvas_size is not None:
+            raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.")
+        output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
         return datapoints.BoundingBoxes.wrap_like(inpt, output)
     else:
         raise TypeError(
...
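The loop over a `{type: kernel}` dict above is what the TODO calls a "poor man's dispatcher": instead of reading `inpt.spatial_size` off the datapoint, each supported type is mapped to an explicit kernel and the first `isinstance` match wins, with a `TypeError` fallback. As a standalone sketch of the pattern with hypothetical types (not torchvision code):

```python
from typing import Callable, Dict, List, Type


class Photo:  # hypothetical stand-in for an image-like datapoint
    shape = (3, 480, 640)


class Clip:  # hypothetical stand-in for a video-like datapoint
    shape = (16, 3, 240, 320)


def _size_from_shape(obj) -> List[int]:
    # The last two dimensions carry the spatial extent, i.e. [h, w].
    return list(obj.shape[-2:])


_SIZE_KERNELS: Dict[Type, Callable[..., List[int]]] = {
    Photo: _size_from_shape,
    Clip: _size_from_shape,
}


def get_size(obj) -> List[int]:
    # Try each registered type in turn; order matters if types overlap.
    for typ, kernel in _SIZE_KERNELS.items():
        if isinstance(obj, typ):
            return kernel(obj)
    raise TypeError(f"Unsupported input type {type(obj)}")


assert get_size(Photo()) == [480, 640]
```

A proper registration mechanism is deferred to the PR referenced in the TODO.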
@@ -6,15 +6,15 @@ import PIL.Image
 from torchvision import datapoints
 from torchvision._utils import sequence_to_str
-from torchvision.transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor
+from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor

 def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
     bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)]
     if not bounding_boxes:
-        raise TypeError("No bounding box was found in the sample")
+        raise TypeError("No bounding boxes were found in the sample")
     elif len(bounding_boxes) > 1:
-        raise ValueError("Found multiple bounding boxes in the sample")
+        raise ValueError("Found multiple bounding boxes instances in the sample")
     return bounding_boxes.pop()
@@ -22,7 +22,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
-        if isinstance(inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video)) or is_simple_tensor(inpt)
+        if check_type(inpt, (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
@@ -32,14 +32,21 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
     return c, h, w

-def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
+def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
     sizes = {
-        tuple(get_spatial_size(inpt))
+        tuple(get_size(inpt))
         for inpt in flat_inputs
-        if isinstance(
-            inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video, datapoints.Mask, datapoints.BoundingBoxes)
+        if check_type(
+            inpt,
+            (
+                is_simple_tensor,
+                datapoints.Image,
+                PIL.Image.Image,
+                datapoints.Video,
+                datapoints.Mask,
+                datapoints.BoundingBoxes,
+            ),
         )
-        or is_simple_tensor(inpt)
     }
     if not sizes:
         raise TypeError("No image, video, mask or bounding box was found in the sample")
...
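For context (not part of the commit): a hypothetical custom v2 transform showing how `_get_params` implementations typically consume the renamed helper, mirroring the pattern used by `RandomResizedCrop` and friends above. The class name and parameters are made up for illustration, and it assumes the v2 `Transform` base class and `functional.crop` as of this commit:

```python
from typing import Any, Dict, List

import torch
from torchvision.transforms.v2 import Transform, functional as F
from torchvision.transforms.v2.utils import query_size  # formerly query_spatial_size


class RandomTopCrop(Transform):
    """Hypothetical transform: drop a random number of rows from the top of the sample."""

    def __init__(self, max_fraction: float = 0.2) -> None:
        super().__init__()
        self.max_fraction = max_fraction

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = query_size(flat_inputs)  # (h, w) shared by all inputs in the sample
        top = int(torch.randint(0, int(height * self.max_fraction) + 1, ()))
        return dict(top=top, height=height - top, width=width)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.crop(inpt, top=params["top"], left=0, height=params["height"], width=params["width"])
```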