Unverified commit 4d4711d9, authored by Vasilis Vryniotis and committed by GitHub

[prototype] Switch to `spatial_size` (#6736)

* Change `image_size` to `spatial_size`

* Fix linter

* Fixing more tests.

* Adding get_num_channels_video and get_spatial_size_* kernels for video, masks and bboxes.

* Refactor get_spatial_size

* Reduce the usage of `query_chw` where possible

* Rename `query_chw` to `query_spatial_size`

* Adding `get_num_frames` dispatcher and kernel.

* Adding jit-scriptability tests
parent 3099e0cc
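For orientation before the diff: the commit renames the `image_size` metadata and kernel arguments to `spatial_size` throughout the prototype transforms, and swaps `query_chw` for a new `query_spatial_size` helper wherever a transform only needs height and width. A minimal sketch of how the renamed API is expected to look after this change; the `spatial_size` constructor keyword is an assumption mirroring the attribute renamed below, while the `F.clamp_bounding_box(..., spatial_size=...)` call appears verbatim in the diff:

```python
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# A bounding box now carries `spatial_size` (height, width) instead of `image_size`.
box = features.BoundingBox(
    [[10.0, 20.0, 250.0, 300.0]],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=(240, 320),  # assumed keyword, matching the renamed attribute in this commit
)

# Kernels and dispatchers that used to take `image_size` now take `spatial_size`.
clamped = F.clamp_bounding_box(box, format=box.format, spatial_size=box.spatial_size)
```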
@@ -78,7 +78,7 @@ class RandomGrayscale(_RandomApplyTransform):
         super().__init__(p=p)

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        num_input_channels, _, _ = query_chw(sample)
+        num_input_channels, *_ = query_chw(sample)
         return dict(num_input_channels=num_input_channels)

     def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
...
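A small aside on the unpacking change above: `*_` swallows however many trailing values `query_chw` returns, so only the channel count is kept. A pure-Python illustration:

```python
# query_chw returns (channels, height, width); the transform only needs the channel count.
c, h, w = 3, 224, 224
num_input_channels, *_ = (c, h, w)
assert num_input_channels == 3
```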
@@ -24,7 +24,7 @@ from ._utils import (
     has_all,
     has_any,
     query_bounding_box,
-    query_chw,
+    query_spatial_size,
 )
@@ -105,10 +105,7 @@ class RandomResizedCrop(Transform):
         self._log_ratio = torch.log(torch.tensor(self.ratio))

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        # vfdev-5: techically, this op can work on bboxes/segm masks only inputs without image in samples
-        # What if we have multiple images/bboxes/masks of different sizes ?
-        # TODO: let's support bbox or mask in samples without image
-        _, height, width = query_chw(sample)
+        height, width = query_spatial_size(sample)
         area = height * width

         log_ratio = self._log_ratio
@@ -263,7 +260,7 @@ class RandomZoomOut(_RandomApplyTransform):
             raise ValueError(f"Invalid canvas side range provided {side_range}.")

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        _, orig_h, orig_w = query_chw(sample)
+        orig_h, orig_w = query_spatial_size(sample)

         r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
         canvas_width = int(orig_w * r)
@@ -362,10 +359,7 @@ class RandomAffine(Transform):
         self.center = center

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        # Get image size
-        # TODO: make it work with bboxes and segm masks
-        _, height, width = query_chw(sample)
+        height, width = query_spatial_size(sample)

         angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
         if self.translate is not None:
@@ -427,7 +421,7 @@ class RandomCrop(Transform):
         self.padding_mode = padding_mode

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        _, padded_height, padded_width = query_chw(sample)
+        padded_height, padded_width = query_spatial_size(sample)

         if self.padding is not None:
             pad_left, pad_right, pad_top, pad_bottom = self.padding
@@ -515,9 +509,7 @@ class RandomPerspective(_RandomApplyTransform):
         self.fill = _setup_fill_arg(fill)

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        # Get image size
-        # TODO: make it work with bboxes and segm masks
-        _, height, width = query_chw(sample)
+        height, width = query_spatial_size(sample)

         distortion_scale = self.distortion_scale
@@ -571,9 +563,7 @@ class ElasticTransform(Transform):
         self.fill = _setup_fill_arg(fill)

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        # Get image size
-        # TODO: make it work with bboxes and segm masks
-        _, *size = query_chw(sample)
+        size = list(query_spatial_size(sample))

         dx = torch.rand([1, 1] + size) * 2 - 1
         if self.sigma[0] > 0.0:
@@ -628,7 +618,7 @@ class RandomIoUCrop(Transform):
         self.trials = trials

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        _, orig_h, orig_w = query_chw(sample)
+        orig_h, orig_w = query_spatial_size(sample)
         bboxes = query_bounding_box(sample)

         while True:
@@ -690,7 +680,7 @@ class RandomIoUCrop(Transform):
         if isinstance(output, features.BoundingBox):
             bboxes = output[is_within_crop_area]
-            bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size)
+            bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size)
             output = features.BoundingBox.wrap_like(output, bboxes)
         elif isinstance(output, features.Mask):
             # apply is_within_crop_area if mask is one-hot encoded
@@ -727,7 +717,7 @@ class ScaleJitter(Transform):
         self.antialias = antialias

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        _, orig_height, orig_width = query_chw(sample)
+        orig_height, orig_width = query_spatial_size(sample)

         scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
         r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
@@ -755,7 +745,7 @@ class RandomShortestSize(Transform):
         self.antialias = antialias

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        _, orig_height, orig_width = query_chw(sample)
+        orig_height, orig_width = query_spatial_size(sample)

         min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
         r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
@@ -786,7 +776,7 @@ class FixedSizeCrop(Transform):
         self.padding_mode = padding_mode

     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        _, height, width = query_chw(sample)
+        height, width = query_spatial_size(sample)
         new_height = min(height, self.crop_height)
         new_width = min(width, self.crop_width)
@@ -811,7 +801,7 @@ class FixedSizeCrop(Transform):
             bounding_boxes = features.BoundingBox.wrap_like(
                 bounding_boxes,
                 F.clamp_bounding_box(
-                    bounding_boxes, format=bounding_boxes.format, image_size=bounding_boxes.image_size
+                    bounding_boxes, format=bounding_boxes.format, spatial_size=bounding_boxes.spatial_size
                 ),
             )
             height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
@@ -851,7 +841,7 @@ class FixedSizeCrop(Transform):
         elif isinstance(inpt, features.BoundingBox):
             inpt = features.BoundingBox.wrap_like(
                 inpt,
-                F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, image_size=inpt.image_size),
+                F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size),
             )

         if params["needs_pad"]:
...
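A hedged sketch of the pattern the transforms above now follow: `_get_params` asks the sample only for its (height, width) via `query_spatial_size` instead of unpacking and discarding the channel count from `query_chw`. The transform below is hypothetical and purely illustrative; `query_spatial_size` is a private helper per this diff, and the `F.crop` dispatcher signature is an assumption:

```python
from typing import Any, Dict

import torch

from torchvision.prototype.transforms import Transform, functional as F
from torchvision.prototype.transforms._utils import query_spatial_size  # private helper, per this diff


class RandomQuarterCrop(Transform):
    """Hypothetical transform: crops a random quarter-sized window of the input."""

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        height, width = query_spatial_size(sample)
        crop_h, crop_w = height // 2, width // 2
        top = int(torch.randint(0, height - crop_h + 1, ()))
        left = int(torch.randint(0, width - crop_w + 1, ()))
        return dict(top=top, left=left, height=crop_h, width=crop_w)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Assumes the prototype functional namespace exposes a `crop` dispatcher with this signature.
        return F.crop(inpt, **params)
```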
@@ -68,5 +68,5 @@ class ClampBoundingBoxes(Transform):
     _transformed_types = (features.BoundingBox,)

     def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
-        output = F.clamp_bounding_box(inpt, format=inpt.format, image_size=inpt.image_size)
+        output = F.clamp_bounding_box(inpt, format=inpt.format, spatial_size=inpt.spatial_size)
         return features.BoundingBox.wrap_like(inpt, output)
@@ -10,7 +10,7 @@ from torchvision._utils import sequence_to_str
 from torchvision.prototype import features
 from torchvision.prototype.features._feature import FillType
-from torchvision.prototype.transforms.functional._meta import get_dimensions
+from torchvision.prototype.transforms.functional._meta import get_dimensions, get_spatial_size
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
 from typing_extensions import Literal
@@ -98,6 +98,22 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
     return c, h, w


+def query_spatial_size(sample: Any) -> Tuple[int, int]:
+    flat_sample, _ = tree_flatten(sample)
+    sizes = {
+        tuple(get_spatial_size(item))
+        for item in flat_sample
+        if isinstance(item, (features.Image, PIL.Image.Image, features.Video, features.Mask, features.BoundingBox))
+        or features.is_simple_tensor(item)
+    }
+    if not sizes:
+        raise TypeError("No image, video, mask or bounding box was found in the sample")
+    elif len(sizes) > 1:
+        raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
+    h, w = sizes.pop()
+    return h, w
+
+
 def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
     for type_or_check in types_or_checks:
         if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
...
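A hedged usage sketch of the new `query_spatial_size` helper added above: it flattens the sample, collects the HxW of every image, video, mask, bounding box, or plain tensor it finds, and requires them all to agree. The import path is an assumption based on the surrounding `_utils` module in this diff:

```python
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms._utils import query_spatial_size  # assumed location (private helper)

sample = {
    "image": features.Image(torch.rand(3, 240, 320)),
    "mask": features.Mask(torch.zeros(240, 320, dtype=torch.uint8)),
}
print(query_spatial_size(sample))  # expected: (240, 320)

# Mismatched spatial sizes within one sample are rejected.
bad_sample = {
    "image": features.Image(torch.rand(3, 240, 320)),
    "mask": features.Mask(torch.zeros(120, 160, dtype=torch.uint8)),
}
try:
    query_spatial_size(bad_sample)
except ValueError as err:
    print(err)  # Found multiple HxW dimensions in the sample: ...
```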
@@ -11,12 +11,18 @@ from ._meta import (
     get_dimensions_image_tensor,
     get_dimensions_image_pil,
     get_dimensions,
+    get_num_frames_video,
+    get_num_frames,
     get_image_num_channels,
     get_num_channels_image_tensor,
     get_num_channels_image_pil,
+    get_num_channels_video,
     get_num_channels,
+    get_spatial_size_bounding_box,
     get_spatial_size_image_tensor,
     get_spatial_size_image_pil,
+    get_spatial_size_mask,
+    get_spatial_size_video,
     get_spatial_size,
 )  # usort: skip
...
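With the re-exports above, the new kernels and dispatchers should be importable directly from the prototype functional namespace. A sketch, assuming the package re-exports the names exactly as listed:

```python
from torchvision.prototype.transforms.functional import (
    get_num_channels_video,
    get_num_frames,
    get_spatial_size_bounding_box,
    get_spatial_size_mask,
    get_spatial_size_video,
)
```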
@@ -32,7 +32,7 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:

 def horizontal_flip_bounding_box(
-    bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
+    bounding_box: torch.Tensor, format: features.BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
     shape = bounding_box.shape
@@ -40,7 +40,7 @@ def horizontal_flip_bounding_box(
         bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
     ).view(-1, 4)

-    bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]
+    bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]

     return convert_format_bounding_box(
         bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
@@ -69,7 +69,7 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:

 def vertical_flip_bounding_box(
-    bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
+    bounding_box: torch.Tensor, format: features.BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
     shape = bounding_box.shape
@@ -77,7 +77,7 @@ def vertical_flip_bounding_box(
         bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
     ).view(-1, 4)

-    bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]]
+    bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]

     return convert_format_bounding_box(
         bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
@@ -104,11 +104,11 @@ vflip = vertical_flip

 def _compute_resized_output_size(
-    image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
+    spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
 ) -> List[int]:
     if isinstance(size, int):
         size = [size]
-    return __compute_resized_output_size(image_size, size=size, max_size=max_size)
+    return __compute_resized_output_size(spatial_size, size=size, max_size=max_size)


 def resize_image_tensor(
@@ -162,10 +162,10 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N

 def resize_bounding_box(
-    bounding_box: torch.Tensor, image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
+    bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    old_height, old_width = image_size
-    new_height, new_width = _compute_resized_output_size(image_size, size=size, max_size=max_size)
+    old_height, old_width = spatial_size
+    new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)
     ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
     return (
         bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape),
@@ -312,7 +312,7 @@ def affine_image_pil(

 def _affine_bounding_box_xyxy(
     bounding_box: torch.Tensor,
-    image_size: Tuple[int, int],
+    spatial_size: Tuple[int, int],
     angle: Union[int, float],
     translate: List[float],
     scale: float,
@@ -325,7 +325,7 @@ def _affine_bounding_box_xyxy(
     )

     if center is None:
-        height, width = image_size
+        height, width = spatial_size
         center = [width * 0.5, height * 0.5]

     dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
@@ -359,7 +359,7 @@ def _affine_bounding_box_xyxy(
     if expand:
         # Compute minimum point for transformed image frame:
         # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
-        height, width = image_size
+        height, width = spatial_size
         points = torch.tensor(
             [
                 [0.0, 0.0, 1.0],
@@ -378,15 +378,15 @@ def _affine_bounding_box_xyxy(
         # Estimate meta-data for image with inverted=True and with center=[0,0]
         affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
         new_width, new_height = _FT._compute_affine_output_size(affine_vector, width, height)
-        image_size = (new_height, new_width)
+        spatial_size = (new_height, new_width)

-    return out_bboxes.to(bounding_box.dtype), image_size
+    return out_bboxes.to(bounding_box.dtype), spatial_size


 def affine_bounding_box(
     bounding_box: torch.Tensor,
     format: features.BoundingBoxFormat,
-    image_size: Tuple[int, int],
+    spatial_size: Tuple[int, int],
     angle: Union[int, float],
     translate: List[float],
     scale: float,
@@ -398,7 +398,7 @@ def affine_bounding_box(
         bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
     ).view(-1, 4)

-    out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, image_size, angle, translate, scale, shear, center)
+    out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center)

     # out_bboxes should be of shape [N boxes, 4]
@@ -573,7 +573,7 @@ def rotate_image_pil(

 def rotate_bounding_box(
     bounding_box: torch.Tensor,
     format: features.BoundingBoxFormat,
-    image_size: Tuple[int, int],
+    spatial_size: Tuple[int, int],
     angle: float,
     expand: bool = False,
     center: Optional[List[float]] = None,
@@ -587,9 +587,9 @@ def rotate_bounding_box(
         bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
     ).view(-1, 4)

-    out_bboxes, image_size = _affine_bounding_box_xyxy(
+    out_bboxes, spatial_size = _affine_bounding_box_xyxy(
         bounding_box,
-        image_size,
+        spatial_size,
         angle=-angle,
         translate=[0.0, 0.0],
         scale=1.0,
@@ -602,7 +602,7 @@ def rotate_bounding_box(
         convert_format_bounding_box(
             out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
         ).view(original_shape),
-        image_size,
+        spatial_size,
     )
@@ -756,7 +756,7 @@ def pad_mask(

 def pad_bounding_box(
     bounding_box: torch.Tensor,
     format: features.BoundingBoxFormat,
-    image_size: Tuple[int, int],
+    spatial_size: Tuple[int, int],
     padding: Union[int, List[int]],
     padding_mode: str = "constant",
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
@@ -775,7 +775,7 @@ def pad_bounding_box(
     bounding_box[..., 2] += left
     bounding_box[..., 3] += top

-    height, width = image_size
+    height, width = spatial_size
     height += top + bottom
     width += left + right
@@ -1066,10 +1066,10 @@ def elastic_bounding_box(
     ).view(-1, 4)

     # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
-    # Or add image_size arg and check displacement shape
-    image_size = displacement.shape[-3], displacement.shape[-2]
+    # Or add spatial_size arg and check displacement shape
+    spatial_size = displacement.shape[-3], displacement.shape[-2]

-    id_grid = _FT._create_identity_grid(list(image_size)).to(bounding_box.device)
+    id_grid = _FT._create_identity_grid(list(spatial_size)).to(bounding_box.device)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
     inv_grid = id_grid - displacement
@@ -1079,7 +1079,7 @@ def elastic_bounding_box(
     index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long)
     index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long)
     # Transform points:
-    t_size = torch.tensor(image_size[::-1], device=displacement.device, dtype=displacement.dtype)
+    t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
     transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5
     transformed_points = transformed_points.view(-1, 4, 2)
@@ -1199,11 +1199,11 @@ def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL

 def center_crop_bounding_box(
     bounding_box: torch.Tensor,
     format: features.BoundingBoxFormat,
-    image_size: Tuple[int, int],
+    spatial_size: Tuple[int, int],
     output_size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *image_size)
+    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size)
     return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width)
...
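A hedged sketch of calling one of the renamed low-level kernels. `horizontal_flip_bounding_box` operates on a plain tensor; `spatial_size` is `(height, width)`, and x-coordinates are mirrored as `x' = width - x`. The import from the public functional namespace is an assumption; the signature follows the diff above:

```python
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms.functional import horizontal_flip_bounding_box  # assumed re-export

boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])  # one XYXY box
flipped = horizontal_flip_bounding_box(
    boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=(240, 320)
)
print(flipped)  # expected: tensor([[270., 20., 310., 80.]])
```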
@@ -18,7 +18,7 @@ def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
         return get_dimensions_image_tensor(image)
     elif isinstance(image, (features.Image, features.Video)):
         channels = image.num_channels
-        height, width = image.image_size
+        height, width = image.spatial_size
         return [channels, height, width]
     else:
         return get_dimensions_image_pil(image)
@@ -28,6 +28,10 @@ get_num_channels_image_tensor = _FT.get_image_num_channels
 get_num_channels_image_pil = _FP.get_image_num_channels


+def get_num_channels_video(video: torch.Tensor) -> int:
+    return get_num_channels_image_tensor(video)
+
+
 def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
     if isinstance(image, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
@@ -55,21 +59,39 @@ def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
     return [height, width]


-# TODO: Should we have get_spatial_size_video here? How about masks/bbox etc? What is the criterion for deciding when
-# a kernel will be created?
+def get_spatial_size_video(video: torch.Tensor) -> List[int]:
+    return get_spatial_size_image_tensor(video)
+
+
+def get_spatial_size_mask(mask: torch.Tensor) -> List[int]:
+    return get_spatial_size_image_tensor(mask)
+
+
+@torch.jit.unused
+def get_spatial_size_bounding_box(bounding_box: features.BoundingBox) -> List[int]:
+    return list(bounding_box.spatial_size)


 def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
     if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
         return get_spatial_size_image_tensor(inpt)
-    elif isinstance(inpt, features._Feature):
-        image_size = getattr(inpt, "image_size", None)
-        if image_size is not None:
-            return list(image_size)
-        else:
-            raise ValueError(f"Type {inpt.__class__} doesn't have spatial size.")
+    elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)):
+        return list(inpt.spatial_size)
     else:
-        return get_spatial_size_image_pil(inpt)
+        return get_spatial_size_image_pil(inpt)  # type: ignore[no-any-return]
+
+
+def get_num_frames_video(video: torch.Tensor) -> int:
+    return video.shape[-4]
+
+
+def get_num_frames(inpt: features.VideoTypeJIT) -> int:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features.Video)):
+        return get_num_frames_video(inpt)
+    elif isinstance(inpt, features.Video):
+        return inpt.num_frames
+    else:
+        raise TypeError(f"The video should be a Tensor. Got {type(inpt)}")


 def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
@@ -125,13 +147,13 @@ def convert_format_bounding_box(

 def clamp_bounding_box(
-    bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]
+    bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
     # TODO: (PERF) Possible speed up clamping if we have different implementations for each bbox format.
     # Not sure if they yield equivalent results.
     xyxy_boxes = convert_format_bounding_box(bounding_box, format, BoundingBoxFormat.XYXY)
-    xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1])
-    xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0])
+    xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
+    xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
     return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False)
...
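Finally, a hedged sketch of the two new dispatchers defined above applied to a plain video tensor of shape `(..., T, C, H, W)`: `get_num_frames` reads `shape[-4]`, and `get_spatial_size` falls back to the image-tensor kernel for plain tensors:

```python
import torch

from torchvision.prototype.transforms.functional import get_num_frames, get_spatial_size

video = torch.rand(8, 3, 240, 320)  # T, C, H, W
print(get_num_frames(video))    # 8  (shape[-4])
print(get_spatial_size(video))  # expected: [240, 320]
```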