Unverified Commit 312c3d32 authored by Philip Meier, committed by GitHub

remove spatial_size (#7734)

parent bdf16222
......@@ -9,7 +9,7 @@ from torchvision import datapoints, transforms as _transforms
from torchvision.transforms import _functional_tensor as _FT
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_spatial_size
from torchvision.transforms.v2.functional._meta import get_size
from ._utils import _setup_fill_arg
from .utils import check_type, is_simple_tensor
......@@ -312,7 +312,7 @@ class AutoAugment(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(image_or_video)
height, width = get_size(image_or_video)
policy = self._policies[int(torch.randint(len(self._policies), ()))]
......@@ -403,7 +403,7 @@ class RandAugment(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(image_or_video)
height, width = get_size(image_or_video)
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
......@@ -474,7 +474,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(image_or_video)
height, width = get_size(image_or_video)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
......@@ -568,7 +568,7 @@ class AugMix(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(orig_image_or_video)
height, width = get_size(orig_image_or_video)
if isinstance(orig_image_or_video, torch.Tensor):
image_or_video = orig_image_or_video
......
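
The hunks above swap `get_spatial_size` for `get_size` throughout the AutoAugment family; both return the size as `[height, width]`. A minimal usage sketch of the renamed query, assuming the functional export added by this commit:

```python
# Minimal sketch, assuming the post-rename torchvision.transforms.v2 functional API.
import torch
from torchvision.transforms.v2 import functional as F

image_or_video = torch.rand(3, 224, 320)    # CHW image tensor
height, width = F.get_size(image_or_video)  # [224, 320], i.e. [H, W]
```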
......@@ -22,7 +22,7 @@ from ._utils import (
_setup_float_or_seq,
_setup_size,
)
from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_spatial_size
from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size
class RandomHorizontalFlip(_RandomApplyTransform):
......@@ -267,7 +267,7 @@ class RandomResizedCrop(Transform):
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
height, width = query_size(flat_inputs)
area = height * width
log_ratio = self._log_ratio
......@@ -558,7 +558,7 @@ class RandomZoomOut(_RandomApplyTransform):
raise ValueError(f"Invalid canvas side range provided {side_range}.")
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(flat_inputs)
orig_h, orig_w = query_size(flat_inputs)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
......@@ -735,7 +735,7 @@ class RandomAffine(Transform):
self.center = center
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
height, width = query_size(flat_inputs)
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
if self.translate is not None:
......@@ -859,7 +859,7 @@ class RandomCrop(Transform):
self.padding_mode = padding_mode
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
padded_height, padded_width = query_spatial_size(flat_inputs)
padded_height, padded_width = query_size(flat_inputs)
if self.padding is not None:
pad_left, pad_right, pad_top, pad_bottom = self.padding
......@@ -972,7 +972,7 @@ class RandomPerspective(_RandomApplyTransform):
self._fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
height, width = query_size(flat_inputs)
distortion_scale = self.distortion_scale
......@@ -1072,7 +1072,7 @@ class ElasticTransform(Transform):
self._fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = list(query_spatial_size(flat_inputs))
size = list(query_size(flat_inputs))
dx = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[0] > 0.0:
......@@ -1164,7 +1164,7 @@ class RandomIoUCrop(Transform):
)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(flat_inputs)
orig_h, orig_w = query_size(flat_inputs)
bboxes = query_bounding_boxes(flat_inputs)
while True:
......@@ -1282,7 +1282,7 @@ class ScaleJitter(Transform):
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(flat_inputs)
orig_height, orig_width = query_size(flat_inputs)
scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
......@@ -1347,7 +1347,7 @@ class RandomShortestSize(Transform):
self.antialias = antialias
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(flat_inputs)
orig_height, orig_width = query_size(flat_inputs)
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
r = min_size / min(orig_height, orig_width)
......
......@@ -30,7 +30,7 @@ class ConvertBoundingBoxFormat(Transform):
class ClampBoundingBoxes(Transform):
"""[BETA] Clamp bounding boxes to their corresponding image dimensions.
The clamping is done according to the bounding boxes' ``spatial_size`` meta-data.
The clamping is done according to the bounding boxes' ``canvas_size`` meta-data.
.. v2betastatus:: ClampBoundingBoxes transform
......
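
A hedged sketch of the renamed meta-data in use: clamping a `BoundingBoxes` datapoint against its `canvas_size`. It assumes the datapoint constructor also takes the renamed `canvas_size` keyword, matching the attribute accesses elsewhere in this diff.

```python
# Sketch only: assumes BoundingBoxes accepts the renamed canvas_size keyword.
import torch
from torchvision import datapoints
from torchvision.transforms import v2

boxes = datapoints.BoundingBoxes(
    torch.tensor([[-10.0, 5.0, 300.0, 150.0]]),  # partly outside the canvas
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(128, 128),                      # (height, width) meta-data
)
clamped = v2.ClampBoundingBoxes()(boxes)         # -> [[0., 5., 128., 128.]]
```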
......@@ -408,7 +408,7 @@ class SanitizeBoundingBoxes(Transform):
valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h, image_w = boxes.spatial_size
image_h, image_w = boxes.canvas_size
valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
......
......@@ -15,12 +15,12 @@ from ._meta import (
get_num_channels_image_pil,
get_num_channels_video,
get_num_channels,
get_spatial_size_bounding_boxes,
get_spatial_size_image_tensor,
get_spatial_size_image_pil,
get_spatial_size_mask,
get_spatial_size_video,
get_spatial_size,
get_size_bounding_boxes,
get_size_image_tensor,
get_size_image_pil,
get_size_mask,
get_size_video,
get_size,
) # usort: skip
from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video
......
......@@ -19,6 +19,6 @@ def to_tensor(inpt: Any) -> torch.Tensor:
def get_image_size(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
warnings.warn(
"The function `get_image_size(...)` is deprecated and will be removed in a future release. "
"Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
"Instead, please use `get_size(...)` which returns `[h, w]` instead of `[w, h]`."
)
return _F.get_image_size(inpt)
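
The ordering caveat from the deprecation message, illustrated on a PIL image: PIL's `Image.size` is `(width, height)`, whereas `get_size` returns `[height, width]`.

```python
import PIL.Image
from torchvision.transforms.v2 import functional as F

image = PIL.Image.new("RGB", (640, 480))  # PIL.Image.new takes (width, height)
print(image.size)                         # (640, 480) -> (W, H)
print(F.get_size(image))                  # [480, 640] -> [H, W]
```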
......@@ -23,7 +23,7 @@ from torchvision.transforms.functional import (
from torchvision.utils import _log_api_usage_once
from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_spatial_size_image_pil
from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
from ._utils import is_simple_tensor
......@@ -52,18 +52,18 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
def horizontal_flip_bounding_boxes(
bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
shape = bounding_boxes.shape
bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
if format == datapoints.BoundingBoxFormat.XYXY:
bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(spatial_size[1]).neg_()
bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_()
elif format == datapoints.BoundingBoxFormat.XYWH:
bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(spatial_size[1]).neg_()
bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_()
else: # format == datapoints.BoundingBoxFormat.CXCYWH:
bounding_boxes[:, 0].sub_(spatial_size[1]).neg_()
bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()
return bounding_boxes.reshape(shape)
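
The XYXY branch maps each x-coordinate to `width - x` and swaps the two x columns so that `x1 <= x2` still holds after the flip. A standalone plain-torch sketch of that arithmetic (not the torchvision kernel itself):

```python
import torch

canvas_size = (100, 200)                          # (height, width)
boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])  # XYXY
flipped = boxes.clone()
flipped[:, [2, 0]] = boxes[:, [0, 2]].sub(canvas_size[1]).neg()
print(flipped)                                    # tensor([[150., 20., 190., 80.]])
```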
......@@ -102,18 +102,18 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
def vertical_flip_bounding_boxes(
bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, spatial_size: Tuple[int, int]
bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
shape = bounding_boxes.shape
bounding_boxes = bounding_boxes.clone().reshape(-1, 4)
if format == datapoints.BoundingBoxFormat.XYXY:
bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(spatial_size[0]).neg_()
bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_()
elif format == datapoints.BoundingBoxFormat.XYWH:
bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(spatial_size[0]).neg_()
bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_()
else: # format == datapoints.BoundingBoxFormat.CXCYWH:
bounding_boxes[:, 1].sub_(spatial_size[0]).neg_()
bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()
return bounding_boxes.reshape(shape)
......@@ -146,7 +146,7 @@ vflip = vertical_flip
def _compute_resized_output_size(
spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
if isinstance(size, int):
size = [size]
......@@ -155,7 +155,7 @@ def _compute_resized_output_size(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)
return __compute_resized_output_size(canvas_size, size=size, max_size=max_size)
def resize_image_tensor(
......@@ -275,13 +275,13 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
def resize_bounding_boxes(
bounding_boxes: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]:
old_height, old_width = spatial_size
new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)
old_height, old_width = canvas_size
new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)
if (new_height, new_width) == (old_height, old_width):
return bounding_boxes, spatial_size
return bounding_boxes, canvas_size
w_ratio = new_width / old_width
h_ratio = new_height / old_height
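
The collapsed remainder of `resize_bounding_boxes` presumably applies these ratios to the box coordinates; for XYXY boxes that amounts to scaling x by `new_width / old_width` and y by `new_height / old_height`. A standalone sketch of the arithmetic, not the kernel itself:

```python
import torch

old_h, old_w, new_h, new_w = 100, 200, 50, 100
w_ratio, h_ratio = new_w / old_w, new_h / old_h
boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])  # XYXY
print(boxes * torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio]))
# tensor([[ 5., 10., 25., 40.]])
```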
......@@ -643,7 +643,7 @@ def affine_image_pil(
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
if center is None:
height, width = get_spatial_size_image_pil(image)
height, width = get_size_image_pil(image)
center = [width * 0.5, height * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
......@@ -653,7 +653,7 @@ def affine_image_pil(
def _affine_bounding_boxes_with_expand(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
canvas_size: Tuple[int, int],
angle: Union[int, float],
translate: List[float],
scale: float,
......@@ -662,7 +662,7 @@ def _affine_bounding_boxes_with_expand(
expand: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
if bounding_boxes.numel() == 0:
return bounding_boxes, spatial_size
return bounding_boxes, canvas_size
original_shape = bounding_boxes.shape
original_dtype = bounding_boxes.dtype
......@@ -680,7 +680,7 @@ def _affine_bounding_boxes_with_expand(
)
if center is None:
height, width = spatial_size
height, width = canvas_size
center = [width * 0.5, height * 0.5]
affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
......@@ -710,7 +710,7 @@ def _affine_bounding_boxes_with_expand(
if expand:
# Compute minimum point for transformed image frame:
# Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
height, width = spatial_size
height, width = canvas_size
points = torch.tensor(
[
[0.0, 0.0, 1.0],
......@@ -728,21 +728,21 @@ def _affine_bounding_boxes_with_expand(
# Estimate meta-data for image with inverted=True and with center=[0,0]
affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
spatial_size = (new_height, new_width)
canvas_size = (new_height, new_width)
out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
out_bboxes = convert_format_bounding_boxes(
out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
).reshape(original_shape)
out_bboxes = out_bboxes.to(original_dtype)
return out_bboxes, spatial_size
return out_bboxes, canvas_size
def affine_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
canvas_size: Tuple[int, int],
angle: Union[int, float],
translate: List[float],
scale: float,
......@@ -752,7 +752,7 @@ def affine_bounding_boxes(
out_box, _ = _affine_bounding_boxes_with_expand(
bounding_boxes,
format=format,
spatial_size=spatial_size,
canvas_size=canvas_size,
angle=angle,
translate=translate,
scale=scale,
......@@ -930,7 +930,7 @@ def rotate_image_pil(
def rotate_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
canvas_size: Tuple[int, int],
angle: float,
expand: bool = False,
center: Optional[List[float]] = None,
......@@ -941,7 +941,7 @@ def rotate_bounding_boxes(
return _affine_bounding_boxes_with_expand(
bounding_boxes,
format=format,
spatial_size=spatial_size,
canvas_size=canvas_size,
angle=-angle,
translate=[0.0, 0.0],
scale=1.0,
......@@ -1168,7 +1168,7 @@ def pad_mask(
def pad_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
canvas_size: Tuple[int, int],
padding: List[int],
padding_mode: str = "constant",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
......@@ -1184,12 +1184,12 @@ def pad_bounding_boxes(
pad = [left, top, 0, 0]
bounding_boxes = bounding_boxes + torch.tensor(pad, dtype=bounding_boxes.dtype, device=bounding_boxes.device)
height, width = spatial_size
height, width = canvas_size
height += top + bottom
width += left + right
spatial_size = (height, width)
canvas_size = (height, width)
return clamp_bounding_boxes(bounding_boxes, format=format, spatial_size=spatial_size), spatial_size
return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
def pad_video(
......@@ -1261,9 +1261,9 @@ def crop_bounding_boxes(
sub = [left, top, 0, 0]
bounding_boxes = bounding_boxes - torch.tensor(sub, dtype=bounding_boxes.dtype, device=bounding_boxes.device)
spatial_size = (height, width)
canvas_size = (height, width)
return clamp_bounding_boxes(bounding_boxes, format=format, spatial_size=spatial_size), spatial_size
return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
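
Independent of the collapsed format handling above, the net effect of `crop_bounding_boxes` on XYXY boxes is a shift by the crop origin plus a new canvas equal to the crop size. A standalone plain-torch sketch:

```python
import torch

top, left, height, width = 10, 20, 100, 150
boxes = torch.tensor([[30.0, 15.0, 90.0, 70.0]])  # XYXY
cropped = boxes - torch.tensor([left, top, left, top], dtype=boxes.dtype)
canvas_size = (height, width)
print(cropped, canvas_size)  # tensor([[10.,  5., 70., 60.]]) (100, 150)
```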
......@@ -1412,7 +1412,7 @@ def perspective_image_pil(
def perspective_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
canvas_size: Tuple[int, int],
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
coefficients: Optional[List[float]] = None,
......@@ -1493,7 +1493,7 @@ def perspective_bounding_boxes(
out_bboxes = clamp_bounding_boxes(
torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=spatial_size,
canvas_size=canvas_size,
)
# out_bboxes should be of shape [N boxes, 4]
......@@ -1651,7 +1651,7 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: to
def elastic_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
canvas_size: Tuple[int, int],
displacement: torch.Tensor,
) -> torch.Tensor:
if bounding_boxes.numel() == 0:
......@@ -1670,7 +1670,7 @@ def elastic_bounding_boxes(
convert_format_bounding_boxes(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
).reshape(-1, 4)
id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
# This is not an exact inverse of the grid
inv_grid = id_grid.sub_(displacement)
......@@ -1683,7 +1683,7 @@ def elastic_bounding_boxes(
index_x, index_y = index_xy[:, 0], index_xy[:, 1]
# Transform points:
t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
transformed_points = transformed_points.reshape(-1, 4, 2)
......@@ -1691,7 +1691,7 @@ def elastic_bounding_boxes(
out_bboxes = clamp_bounding_boxes(
torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=spatial_size,
canvas_size=canvas_size,
)
return convert_format_bounding_boxes(
......@@ -1804,13 +1804,13 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor
@torch.jit.unused
def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
image_height, image_width = get_spatial_size_image_pil(image)
image_height, image_width = get_size_image_pil(image)
if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
image = pad_image_pil(image, padding_ltrb, fill=0)
image_height, image_width = get_spatial_size_image_pil(image)
image_height, image_width = get_size_image_pil(image)
if crop_width == image_width and crop_height == image_height:
return image
......@@ -1821,11 +1821,11 @@ def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL
def center_crop_bounding_boxes(
bounding_boxes: torch.Tensor,
format: datapoints.BoundingBoxFormat,
spatial_size: Tuple[int, int],
canvas_size: Tuple[int, int],
output_size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size)
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
return crop_bounding_boxes(
bounding_boxes, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width
)
......@@ -1905,7 +1905,7 @@ def resized_crop_bounding_boxes(
size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
bounding_boxes, _ = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
return resize_bounding_boxes(bounding_boxes, spatial_size=(height, width), size=size)
return resize_bounding_boxes(bounding_boxes, canvas_size=(height, width), size=size)
def resized_crop_mask(
......@@ -2000,7 +2000,7 @@ def five_crop_image_pil(
image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
crop_height, crop_width = _parse_five_crop_size(size)
image_height, image_width = get_spatial_size_image_pil(image)
image_height, image_width = get_size_image_pil(image)
if crop_width > image_width or crop_height > image_height:
raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
......
......@@ -26,23 +26,29 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
get_dimensions_image_pil = _FP.get_dimensions
def get_dimensions_video(video: torch.Tensor) -> List[int]:
return get_dimensions_image_tensor(video)
def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]:
if not torch.jit.is_scripting():
_log_api_usage_once(get_dimensions)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return get_dimensions_image_tensor(inpt)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
channels = inpt.num_channels
height, width = inpt.spatial_size
return [channels, height, width]
elif isinstance(inpt, PIL.Image.Image):
return get_dimensions_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
for typ, get_size_fn in {
datapoints.Image: get_dimensions_image_tensor,
datapoints.Video: get_dimensions_video,
PIL.Image.Image: get_dimensions_image_pil,
}.items():
if isinstance(inpt, typ):
return get_size_fn(inpt)
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def get_num_channels_image_tensor(image: torch.Tensor) -> int:
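
The hunk above replaces the `if/elif` chain in `get_dimensions` with a loop over a `{type: kernel}` mapping, the "poor man's dispatcher" referenced in the TODO further down. The pattern in isolation, with hypothetical names rather than the torchvision internals:

```python
from typing import Any, Callable, Dict, Type

def dispatch(inpt: Any, kernels: Dict[Type, Callable[[Any], Any]]) -> Any:
    # Call the first kernel whose registered type matches the input.
    for typ, kernel in kernels.items():
        if isinstance(inpt, typ):
            return kernel(inpt)
    raise TypeError(f"Unsupported input type {type(inpt)}")

print(dispatch([1, 2, 3], {list: len, str: str.upper}))  # 3
```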
......@@ -69,15 +75,19 @@ def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoType
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return get_num_channels_image_tensor(inpt)
elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
return inpt.num_channels
elif isinstance(inpt, PIL.Image.Image):
return get_num_channels_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
for typ, get_size_fn in {
datapoints.Image: get_num_channels_image_tensor,
datapoints.Video: get_num_channels_video,
PIL.Image.Image: get_num_channels_image_pil,
}.items():
if isinstance(inpt, typ):
return get_size_fn(inpt)
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
# We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
......@@ -85,7 +95,7 @@ def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoType
get_image_num_channels = get_num_channels
def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
def get_size_image_tensor(image: torch.Tensor) -> List[int]:
hw = list(image.shape[-2:])
ndims = len(hw)
if ndims == 2:
......@@ -95,39 +105,48 @@ def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
@torch.jit.unused
def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
width, height = _FP.get_image_size(image)
return [height, width]
def get_spatial_size_video(video: torch.Tensor) -> List[int]:
return get_spatial_size_image_tensor(video)
def get_size_video(video: torch.Tensor) -> List[int]:
return get_size_image_tensor(video)
def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:
return get_spatial_size_image_tensor(mask)
def get_size_mask(mask: torch.Tensor) -> List[int]:
return get_size_image_tensor(mask)
@torch.jit.unused
def get_spatial_size_bounding_boxes(bounding_boxes: datapoints.BoundingBoxes) -> List[int]:
return list(bounding_boxes.spatial_size)
def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
return list(bounding_box.canvas_size)
def get_spatial_size(inpt: datapoints._InputTypeJIT) -> List[int]:
def get_size(inpt: datapoints._InputTypeJIT) -> List[int]:
if not torch.jit.is_scripting():
_log_api_usage_once(get_spatial_size)
_log_api_usage_once(get_size)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return get_spatial_size_image_tensor(inpt)
elif isinstance(inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBoxes, datapoints.Mask)):
return list(inpt.spatial_size)
elif isinstance(inpt, PIL.Image.Image):
return get_spatial_size_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
return get_size_image_tensor(inpt)
# TODO: This is just the poor mans version of a dispatcher. This will be properly addressed with
# https://github.com/pytorch/vision/pull/7747 when we can register the kernels above without the need to have
# a method on the datapoint class
for typ, get_size_fn in {
datapoints.Image: get_size_image_tensor,
datapoints.BoundingBoxes: get_size_bounding_boxes,
datapoints.Mask: get_size_mask,
datapoints.Video: get_size_video,
PIL.Image.Image: get_size_image_pil,
}.items():
if isinstance(inpt, typ):
return get_size_fn(inpt)
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def get_num_frames_video(video: torch.Tensor) -> int:
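
For a `BoundingBoxes` datapoint, `get_size` reads the `canvas_size` meta-data rather than the tensor shape. A hedged sketch, again assuming the constructor accepts the renamed keyword:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(480, 640),  # assumed keyword, per the renames in this diff
)
print(F.get_size(boxes))     # [480, 640], independent of the box tensor's shape
```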
......@@ -141,7 +160,7 @@ def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return get_num_frames_video(inpt)
elif isinstance(inpt, datapoints.Video):
return inpt.num_frames
return get_num_frames_video(inpt)
else:
raise TypeError(f"Input can either be a plain tensor or a `Video` datapoint, but got {type(inpt)} instead.")
......@@ -240,7 +259,7 @@ def convert_format_bounding_boxes(
def _clamp_bounding_boxes(
bounding_boxes: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
bounding_boxes: torch.Tensor, format: BoundingBoxFormat, canvas_size: Tuple[int, int]
) -> torch.Tensor:
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
......@@ -249,8 +268,8 @@ def _clamp_bounding_boxes(
xyxy_boxes = convert_format_bounding_boxes(
bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
)
xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])
out_boxes = convert_format_bounding_boxes(
xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
)
......@@ -260,21 +279,20 @@ def _clamp_bounding_boxes(
def clamp_bounding_boxes(
inpt: datapoints._InputTypeJIT,
format: Optional[BoundingBoxFormat] = None,
spatial_size: Optional[Tuple[int, int]] = None,
canvas_size: Optional[Tuple[int, int]] = None,
) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(clamp_bounding_boxes)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if format is None or spatial_size is None:
raise ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.")
return _clamp_bounding_boxes(inpt, format=format, spatial_size=spatial_size)
if format is None or canvas_size is None:
raise ValueError("For simple tensor inputs, `format` and `canvas_size` has to be passed.")
return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
elif isinstance(inpt, datapoints.BoundingBoxes):
if format is not None or spatial_size is not None:
raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
output = _clamp_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
)
if format is not None or canvas_size is not None:
raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.")
output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
return datapoints.BoundingBoxes.wrap_like(inpt, output)
else:
raise TypeError(
......
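
For plain tensor inputs, the dispatcher above requires `format` and `canvas_size` to be passed explicitly. A minimal sketch with the post-rename keyword:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = torch.tensor([[-5.0, 10.0, 250.0, 90.0]])  # XYXY, partly out of canvas
clamped = F.clamp_bounding_boxes(
    boxes,
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(100, 200),                         # (height, width)
)
print(clamped)                                      # tensor([[  0.,  10., 200.,  90.]])
```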
......@@ -6,15 +6,15 @@ import PIL.Image
from torchvision import datapoints
from torchvision._utils import sequence_to_str
from torchvision.transforms.v2.functional import get_dimensions, get_spatial_size, is_simple_tensor
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor
def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)]
if not bounding_boxes:
raise TypeError("No bounding box was found in the sample")
raise TypeError("No bounding boxes were found in the sample")
elif len(bounding_boxes) > 1:
raise ValueError("Found multiple bounding boxes in the sample")
raise ValueError("Found multiple bounding boxes instances in the sample")
return bounding_boxes.pop()
......@@ -22,7 +22,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if isinstance(inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video)) or is_simple_tensor(inpt)
if check_type(inpt, (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
}
if not chws:
raise TypeError("No image or video was found in the sample")
......@@ -32,14 +32,21 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
return c, h, w
def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
sizes = {
tuple(get_spatial_size(inpt))
tuple(get_size(inpt))
for inpt in flat_inputs
if isinstance(
inpt, (datapoints.Image, PIL.Image.Image, datapoints.Video, datapoints.Mask, datapoints.BoundingBoxes)
if check_type(
inpt,
(
is_simple_tensor,
datapoints.Image,
PIL.Image.Image,
datapoints.Video,
datapoints.Mask,
datapoints.BoundingBoxes,
),
)
or is_simple_tensor(inpt)
}
if not sizes:
raise TypeError("No image, video, mask or bounding box was found in the sample")
......
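
A hedged usage sketch of `query_size`, assuming it is importable from `torchvision.transforms.v2.utils` as the relative import in this diff suggests: it collects the size of every size-bearing input and errors unless exactly one distinct size is found.

```python
import torch
from torchvision.transforms.v2.utils import query_size  # assumed import path

flat_inputs = [torch.rand(3, 64, 96), torch.rand(1, 64, 96)]  # image-like tensors
print(query_size(flat_inputs))  # (64, 96); raises if the sizes disagree
```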