Unverified commit 4d4711d9, authored by Vasilis Vryniotis and committed by GitHub

[prototype] Switch to `spatial_size` (#6736)

* Change `image_size` to `spatial_size`

* Fix linter

* Fix more tests

* Add `get_num_channels_video` and `get_spatial_size_*` kernels for videos, masks and bounding boxes

* Refactor get_spatial_size

* Reduce the usage of `query_chw` where possible

* Rename `query_chw` to `query_spatial_size`

* Add `get_num_frames` dispatcher and kernel

* Add JIT-scriptability tests
parent 3099e0cc
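
For orientation, here is a minimal sketch of the renamed API, assembled from the calls that appear in the diff below; `features.BoundingBox` and `query_spatial_size` are prototype APIs on this branch and may still change:

import torch
from torchvision.prototype import features

# Bounding boxes now carry the size of the underlying image as `spatial_size`
# (height, width); before this PR the keyword and attribute were `image_size`.
bbox = features.BoundingBox(
    [0, 0, 5, 5],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=(10, 10),
)
assert bbox.spatial_size == (10, 10)

# Transforms that only need the spatial size no longer go through `query_chw`;
# they use the renamed helper instead, e.g.:
#   from torchvision.prototype.transforms._utils import query_spatial_size
#   height, width = query_spatial_size(sample)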
@@ -184,13 +184,18 @@ class ArgsKwargs:
         return args, kwargs

-DEFAULT_SQUARE_IMAGE_SIZE = 15
-DEFAULT_LANDSCAPE_IMAGE_SIZE = (7, 33)
-DEFAULT_PORTRAIT_IMAGE_SIZE = (31, 9)
-DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE, DEFAULT_PORTRAIT_IMAGE_SIZE, DEFAULT_SQUARE_IMAGE_SIZE, "random")
+DEFAULT_SQUARE_SPATIAL_SIZE = 15
+DEFAULT_LANDSCAPE_SPATIAL_SIZE = (7, 33)
+DEFAULT_PORTRAIT_SPATIAL_SIZE = (31, 9)
+DEFAULT_SPATIAL_SIZES = (
+    DEFAULT_LANDSCAPE_SPATIAL_SIZE,
+    DEFAULT_PORTRAIT_SPATIAL_SIZE,
+    DEFAULT_SQUARE_SPATIAL_SIZE,
+    "random",
+)

-def _parse_image_size(size, *, name="size"):
+def _parse_spatial_size(size, *, name="size"):
     if size == "random":
         return tuple(torch.randint(15, 33, (2,)).tolist())
     elif isinstance(size, int) and size > 0:
@@ -246,11 +251,11 @@ class TensorLoader:
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
     color_space: features.ColorSpace
-    image_size: Tuple[int, int] = dataclasses.field(init=False)
+    spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)

     def __post_init__(self):
-        self.image_size = self.shape[-2:]
+        self.spatial_size = self.shape[-2:]
         self.num_channels = self.shape[-3]
@@ -277,7 +282,7 @@ def make_image_loader(
     dtype=torch.float32,
     constant_alpha=True,
 ):
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)

     def fn(shape, dtype, device):
@@ -295,7 +300,7 @@ make_image = from_loader(make_image_loader)
 def make_image_loaders(
     *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
     color_spaces=(
         features.ColorSpace.GRAY,
         features.ColorSpace.GRAY_ALPHA,
@@ -316,7 +321,7 @@ make_images = from_loaders(make_image_loaders)
 @dataclasses.dataclass
 class BoundingBoxLoader(TensorLoader):
     format: features.BoundingBoxFormat
-    image_size: Tuple[int, int]
+    spatial_size: Tuple[int, int]

 def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
@@ -331,7 +336,7 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
     ).reshape(low.shape)

-def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtype=torch.float32):
+def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dtype=torch.float32):
     if isinstance(format, str):
         format = features.BoundingBoxFormat[format]
     if format not in {
@@ -341,7 +346,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtyp
     }:
         raise pytest.UsageError(f"Can't make bounding box in format {format}")
-    image_size = _parse_image_size(image_size, name="image_size")
+    spatial_size = _parse_spatial_size(spatial_size, name="spatial_size")

     def fn(shape, dtype, device):
         *extra_dims, num_coordinates = shape
@@ -350,10 +355,10 @@ def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtyp
         if any(dim == 0 for dim in extra_dims):
             return features.BoundingBox(
-                torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, image_size=image_size
+                torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
             )

-        height, width = image_size
+        height, width = spatial_size
         if format == features.BoundingBoxFormat.XYXY:
             x1 = torch.randint(0, width // 2, extra_dims)
@@ -375,10 +380,10 @@ def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtyp
             parts = (cx, cy, w, h)
         return features.BoundingBox(
-            torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, image_size=image_size
+            torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
         )

-    return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, image_size=image_size)
+    return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)

 make_bounding_box = from_loader(make_bounding_box_loader)
@@ -388,11 +393,11 @@ def make_bounding_box_loaders(
     *,
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(features.BoundingBoxFormat),
-    image_size="random",
+    spatial_size="random",
     dtypes=(torch.float32, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
-        yield make_bounding_box_loader(**params, image_size=image_size)
+        yield make_bounding_box_loader(**params, spatial_size=spatial_size)

 make_bounding_boxes = from_loaders(make_bounding_box_loaders)
@@ -475,7 +480,7 @@ class MaskLoader(TensorLoader):
 def make_detection_mask_loader(size="random", *, num_objects="random", extra_dims=(), dtype=torch.uint8):
     # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
     num_objects = int(torch.randint(1, 11, ())) if num_objects == "random" else num_objects

     def fn(shape, dtype, device):
@@ -489,7 +494,7 @@ make_detection_mask = from_loader(make_detection_mask_loader)
 def make_detection_mask_loaders(
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
     num_objects=(1, 0, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.uint8,),
@@ -503,7 +508,7 @@ make_detection_masks = from_loaders(make_detection_mask_loaders)
 def make_segmentation_mask_loader(size="random", *, num_categories="random", extra_dims=(), dtype=torch.uint8):
     # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
     num_categories = int(torch.randint(1, 11, ())) if num_categories == "random" else num_categories

     def fn(shape, dtype, device):
@@ -518,7 +523,7 @@ make_segmentation_mask = from_loader(make_segmentation_mask_loader)
 def make_segmentation_mask_loaders(
     *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
     num_categories=(1, 2, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.uint8,),
@@ -532,7 +537,7 @@ make_segmentation_masks = from_loaders(make_segmentation_mask_loaders)
 def make_mask_loaders(
     *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
     num_objects=(1, 0, "random"),
     num_categories=(1, 2, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
@@ -559,7 +564,7 @@ def make_video_loader(
     extra_dims=(),
     dtype=torch.uint8,
 ):
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
     num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames

     def fn(shape, dtype, device):
@@ -576,7 +581,7 @@ make_video = from_loader(make_video_loader)
 def make_video_loaders(
     *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
     color_spaces=(
         features.ColorSpace.GRAY,
         features.ColorSpace.RGB,
...@@ -145,7 +145,7 @@ def sample_inputs_horizontal_flip_bounding_box(): ...@@ -145,7 +145,7 @@ def sample_inputs_horizontal_flip_bounding_box():
formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32] formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
): ):
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
) )
...@@ -185,9 +185,9 @@ KERNEL_INFOS.extend( ...@@ -185,9 +185,9 @@ KERNEL_INFOS.extend(
) )
def _get_resize_sizes(image_size): def _get_resize_sizes(spatial_size):
height, width = image_size height, width = spatial_size
length = max(image_size) length = max(spatial_size)
yield length yield length
yield [length] yield [length]
yield (length,) yield (length,)
...@@ -201,7 +201,7 @@ def sample_inputs_resize_image_tensor(): ...@@ -201,7 +201,7 @@ def sample_inputs_resize_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
): ):
for size in _get_resize_sizes(image_loader.image_size): for size in _get_resize_sizes(image_loader.spatial_size):
yield ArgsKwargs(image_loader, size=size) yield ArgsKwargs(image_loader, size=size)
for image_loader, interpolation in itertools.product( for image_loader, interpolation in itertools.product(
...@@ -212,7 +212,7 @@ def sample_inputs_resize_image_tensor(): ...@@ -212,7 +212,7 @@ def sample_inputs_resize_image_tensor():
F.InterpolationMode.BICUBIC, F.InterpolationMode.BICUBIC,
], ],
): ):
yield ArgsKwargs(image_loader, size=[min(image_loader.image_size) + 1], interpolation=interpolation) yield ArgsKwargs(image_loader, size=[min(image_loader.spatial_size) + 1], interpolation=interpolation)
yield ArgsKwargs(make_image_loader(size=(11, 17)), size=20, max_size=25) yield ArgsKwargs(make_image_loader(size=(11, 17)), size=20, max_size=25)
...@@ -236,7 +236,7 @@ def reference_inputs_resize_image_tensor(): ...@@ -236,7 +236,7 @@ def reference_inputs_resize_image_tensor():
F.InterpolationMode.BICUBIC, F.InterpolationMode.BICUBIC,
], ],
): ):
for size in _get_resize_sizes(image_loader.image_size): for size in _get_resize_sizes(image_loader.spatial_size):
yield ArgsKwargs( yield ArgsKwargs(
image_loader, image_loader,
size=size, size=size,
...@@ -251,8 +251,8 @@ def reference_inputs_resize_image_tensor(): ...@@ -251,8 +251,8 @@ def reference_inputs_resize_image_tensor():
def sample_inputs_resize_bounding_box(): def sample_inputs_resize_bounding_box():
for bounding_box_loader in make_bounding_box_loaders(): for bounding_box_loader in make_bounding_box_loaders():
for size in _get_resize_sizes(bounding_box_loader.image_size): for size in _get_resize_sizes(bounding_box_loader.spatial_size):
yield ArgsKwargs(bounding_box_loader, size=size, image_size=bounding_box_loader.image_size) yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
def sample_inputs_resize_mask(): def sample_inputs_resize_mask():
...@@ -394,7 +394,7 @@ def sample_inputs_affine_bounding_box(): ...@@ -394,7 +394,7 @@ def sample_inputs_affine_bounding_box():
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_box_loader,
format=bounding_box_loader.format, format=bounding_box_loader.format,
image_size=bounding_box_loader.image_size, spatial_size=bounding_box_loader.spatial_size,
**affine_params, **affine_params,
) )
...@@ -422,9 +422,9 @@ def _compute_affine_matrix(angle, translate, scale, shear, center): ...@@ -422,9 +422,9 @@ def _compute_affine_matrix(angle, translate, scale, shear, center):
return true_matrix return true_matrix
def reference_affine_bounding_box(bounding_box, *, format, image_size, angle, translate, scale, shear, center=None): def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle, translate, scale, shear, center=None):
if center is None: if center is None:
center = [s * 0.5 for s in image_size[::-1]] center = [s * 0.5 for s in spatial_size[::-1]]
def transform(bbox): def transform(bbox):
affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center) affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
...@@ -473,7 +473,7 @@ def reference_inputs_affine_bounding_box(): ...@@ -473,7 +473,7 @@ def reference_inputs_affine_bounding_box():
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_box_loader,
format=bounding_box_loader.format, format=bounding_box_loader.format,
image_size=bounding_box_loader.image_size, spatial_size=bounding_box_loader.spatial_size,
**affine_kwargs, **affine_kwargs,
) )
...@@ -650,7 +650,7 @@ def sample_inputs_vertical_flip_bounding_box(): ...@@ -650,7 +650,7 @@ def sample_inputs_vertical_flip_bounding_box():
formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32] formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
): ):
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
) )
...@@ -729,7 +729,7 @@ def sample_inputs_rotate_bounding_box(): ...@@ -729,7 +729,7 @@ def sample_inputs_rotate_bounding_box():
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_box_loader,
format=bounding_box_loader.format, format=bounding_box_loader.format,
image_size=bounding_box_loader.image_size, spatial_size=bounding_box_loader.spatial_size,
angle=_ROTATE_ANGLES[0], angle=_ROTATE_ANGLES[0],
) )
...@@ -1001,7 +1001,7 @@ def sample_inputs_pad_bounding_box(): ...@@ -1001,7 +1001,7 @@ def sample_inputs_pad_bounding_box():
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_box_loader,
format=bounding_box_loader.format, format=bounding_box_loader.format,
image_size=bounding_box_loader.image_size, spatial_size=bounding_box_loader.spatial_size,
padding=padding, padding=padding,
padding_mode="constant", padding_mode="constant",
) )
...@@ -1131,13 +1131,13 @@ KERNEL_INFOS.extend( ...@@ -1131,13 +1131,13 @@ KERNEL_INFOS.extend(
) )
def _get_elastic_displacement(image_size): def _get_elastic_displacement(spatial_size):
return torch.rand(1, *image_size, 2) return torch.rand(1, *spatial_size, 2)
def sample_inputs_elastic_image_tensor(): def sample_inputs_elastic_image_tensor():
for image_loader in make_image_loaders(sizes=["random"]): for image_loader in make_image_loaders(sizes=["random"]):
displacement = _get_elastic_displacement(image_loader.image_size) displacement = _get_elastic_displacement(image_loader.spatial_size)
for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]: for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
yield ArgsKwargs(image_loader, displacement=displacement, fill=fill) yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)
...@@ -1151,14 +1151,14 @@ def reference_inputs_elastic_image_tensor(): ...@@ -1151,14 +1151,14 @@ def reference_inputs_elastic_image_tensor():
F.InterpolationMode.BICUBIC, F.InterpolationMode.BICUBIC,
], ],
): ):
displacement = _get_elastic_displacement(image_loader.image_size) displacement = _get_elastic_displacement(image_loader.spatial_size)
for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]: for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill) yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)
def sample_inputs_elastic_bounding_box(): def sample_inputs_elastic_bounding_box():
for bounding_box_loader in make_bounding_box_loaders(): for bounding_box_loader in make_bounding_box_loaders():
displacement = _get_elastic_displacement(bounding_box_loader.image_size) displacement = _get_elastic_displacement(bounding_box_loader.spatial_size)
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_box_loader,
format=bounding_box_loader.format, format=bounding_box_loader.format,
...@@ -1212,7 +1212,7 @@ KERNEL_INFOS.extend( ...@@ -1212,7 +1212,7 @@ KERNEL_INFOS.extend(
) )
_CENTER_CROP_IMAGE_SIZES = [(16, 16), (7, 33), (31, 9)] _CENTER_CROP_SPATIAL_SIZES = [(16, 16), (7, 33), (31, 9)]
_CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)] _CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]
...@@ -1231,7 +1231,7 @@ def sample_inputs_center_crop_image_tensor(): ...@@ -1231,7 +1231,7 @@ def sample_inputs_center_crop_image_tensor():
def reference_inputs_center_crop_image_tensor(): def reference_inputs_center_crop_image_tensor():
for image_loader, output_size in itertools.product( for image_loader, output_size in itertools.product(
make_image_loaders(sizes=_CENTER_CROP_IMAGE_SIZES, extra_dims=[()]), _CENTER_CROP_OUTPUT_SIZES make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()]), _CENTER_CROP_OUTPUT_SIZES
): ):
yield ArgsKwargs(image_loader, output_size=output_size) yield ArgsKwargs(image_loader, output_size=output_size)
...@@ -1241,7 +1241,7 @@ def sample_inputs_center_crop_bounding_box(): ...@@ -1241,7 +1241,7 @@ def sample_inputs_center_crop_bounding_box():
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, bounding_box_loader,
format=bounding_box_loader.format, format=bounding_box_loader.format,
image_size=bounding_box_loader.image_size, spatial_size=bounding_box_loader.spatial_size,
output_size=output_size, output_size=output_size,
) )
...@@ -1254,7 +1254,7 @@ def sample_inputs_center_crop_mask(): ...@@ -1254,7 +1254,7 @@ def sample_inputs_center_crop_mask():
def reference_inputs_center_crop_mask(): def reference_inputs_center_crop_mask():
for mask_loader, output_size in itertools.product( for mask_loader, output_size in itertools.product(
make_mask_loaders(sizes=_CENTER_CROP_IMAGE_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES make_mask_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES
): ):
yield ArgsKwargs(mask_loader, output_size=output_size) yield ArgsKwargs(mask_loader, output_size=output_size)
...@@ -1820,7 +1820,7 @@ KERNEL_INFOS.extend( ...@@ -1820,7 +1820,7 @@ KERNEL_INFOS.extend(
def sample_inputs_clamp_bounding_box(): def sample_inputs_clamp_bounding_box():
for bounding_box_loader in make_bounding_box_loaders(): for bounding_box_loader in make_bounding_box_loaders():
yield ArgsKwargs( yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
) )
...@@ -1834,7 +1834,7 @@ KERNEL_INFOS.append( ...@@ -1834,7 +1834,7 @@ KERNEL_INFOS.append(
_FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]] _FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]]
def _get_five_ten_crop_image_size(size): def _get_five_ten_crop_spatial_size(size):
if isinstance(size, int): if isinstance(size, int):
crop_height = crop_width = size crop_height = crop_width = size
elif len(size) == 1: elif len(size) == 1:
...@@ -1847,28 +1847,32 @@ def _get_five_ten_crop_image_size(size): ...@@ -1847,28 +1847,32 @@ def _get_five_ten_crop_image_size(size):
def sample_inputs_five_crop_image_tensor(): def sample_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES: for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_image_size(size)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] sizes=[_get_five_ten_crop_spatial_size(size)],
color_spaces=[features.ColorSpace.RGB],
dtypes=[torch.float32],
): ):
yield ArgsKwargs(image_loader, size=size) yield ArgsKwargs(image_loader, size=size)
def reference_inputs_five_crop_image_tensor(): def reference_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES: for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)], extra_dims=[()]): for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()]):
yield ArgsKwargs(image_loader, size=size) yield ArgsKwargs(image_loader, size=size)
def sample_inputs_ten_crop_image_tensor(): def sample_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_image_size(size)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] sizes=[_get_five_ten_crop_spatial_size(size)],
color_spaces=[features.ColorSpace.RGB],
dtypes=[torch.float32],
): ):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
def reference_inputs_ten_crop_image_tensor(): def reference_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)], extra_dims=[()]): for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()]):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
......
...@@ -298,7 +298,7 @@ class TestRandomHorizontalFlip: ...@@ -298,7 +298,7 @@ class TestRandomHorizontalFlip:
assert_equal(features.Mask(expected), actual) assert_equal(features.Mask(expected), actual)
def test_features_bounding_box(self, p): def test_features_bounding_box(self, p):
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
transform = transforms.RandomHorizontalFlip(p=p) transform = transforms.RandomHorizontalFlip(p=p)
actual = transform(input) actual = transform(input)
...@@ -307,7 +307,7 @@ class TestRandomHorizontalFlip: ...@@ -307,7 +307,7 @@ class TestRandomHorizontalFlip:
expected = features.BoundingBox.wrap_like(input, expected_image_tensor) expected = features.BoundingBox.wrap_like(input, expected_image_tensor)
assert_equal(expected, actual) assert_equal(expected, actual)
assert actual.format == expected.format assert actual.format == expected.format
assert actual.image_size == expected.image_size assert actual.spatial_size == expected.spatial_size
@pytest.mark.parametrize("p", [0.0, 1.0]) @pytest.mark.parametrize("p", [0.0, 1.0])
...@@ -351,7 +351,7 @@ class TestRandomVerticalFlip: ...@@ -351,7 +351,7 @@ class TestRandomVerticalFlip:
assert_equal(features.Mask(expected), actual) assert_equal(features.Mask(expected), actual)
def test_features_bounding_box(self, p): def test_features_bounding_box(self, p):
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
transform = transforms.RandomVerticalFlip(p=p) transform = transforms.RandomVerticalFlip(p=p)
actual = transform(input) actual = transform(input)
...@@ -360,7 +360,7 @@ class TestRandomVerticalFlip: ...@@ -360,7 +360,7 @@ class TestRandomVerticalFlip:
expected = features.BoundingBox.wrap_like(input, expected_image_tensor) expected = features.BoundingBox.wrap_like(input, expected_image_tensor)
assert_equal(expected, actual) assert_equal(expected, actual)
assert actual.format == expected.format assert actual.format == expected.format
assert actual.image_size == expected.image_size assert actual.spatial_size == expected.spatial_size
class TestPad: class TestPad:
...@@ -435,7 +435,7 @@ class TestRandomZoomOut: ...@@ -435,7 +435,7 @@ class TestRandomZoomOut:
transform = transforms.RandomZoomOut(fill=fill, side_range=side_range) transform = transforms.RandomZoomOut(fill=fill, side_range=side_range)
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=features.Image)
h, w = image.image_size = (24, 32) h, w = image.spatial_size = (24, 32)
params = transform._get_params(image) params = transform._get_params(image)
...@@ -450,7 +450,7 @@ class TestRandomZoomOut: ...@@ -450,7 +450,7 @@ class TestRandomZoomOut:
def test__transform(self, fill, side_range, mocker): def test__transform(self, fill, side_range, mocker):
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=features.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.image_size = (24, 32) inpt.spatial_size = (24, 32)
transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1) transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)
...@@ -559,17 +559,17 @@ class TestRandomRotation: ...@@ -559,17 +559,17 @@ class TestRandomRotation:
@pytest.mark.parametrize("angle", [34, -87]) @pytest.mark.parametrize("angle", [34, -87])
@pytest.mark.parametrize("expand", [False, True]) @pytest.mark.parametrize("expand", [False, True])
def test_boundingbox_image_size(self, angle, expand): def test_boundingbox_spatial_size(self, angle, expand):
# Specific test for BoundingBox.rotate # Specific test for BoundingBox.rotate
bbox = features.BoundingBox( bbox = features.BoundingBox(
torch.tensor([1, 2, 3, 4]), format=features.BoundingBoxFormat.XYXY, image_size=(32, 32) torch.tensor([1, 2, 3, 4]), format=features.BoundingBoxFormat.XYXY, spatial_size=(32, 32)
) )
img = features.Image(torch.rand(1, 3, 32, 32)) img = features.Image(torch.rand(1, 3, 32, 32))
out_img = img.rotate(angle, expand=expand) out_img = img.rotate(angle, expand=expand)
out_bbox = bbox.rotate(angle, expand=expand) out_bbox = bbox.rotate(angle, expand=expand)
assert out_img.image_size == out_bbox.image_size assert out_img.spatial_size == out_bbox.spatial_size
class TestRandomAffine: class TestRandomAffine:
...@@ -619,8 +619,8 @@ class TestRandomAffine: ...@@ -619,8 +619,8 @@ class TestRandomAffine:
def test__get_params(self, degrees, translate, scale, shear, mocker): def test__get_params(self, degrees, translate, scale, shear, mocker):
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=features.Image)
image.num_channels = 3 image.num_channels = 3
image.image_size = (24, 32) image.spatial_size = (24, 32)
h, w = image.image_size h, w = image.spatial_size
transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear) transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear)
params = transform._get_params(image) params = transform._get_params(image)
...@@ -682,7 +682,7 @@ class TestRandomAffine: ...@@ -682,7 +682,7 @@ class TestRandomAffine:
fn = mocker.patch("torchvision.prototype.transforms.functional.affine") fn = mocker.patch("torchvision.prototype.transforms.functional.affine")
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=features.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.image_size = (24, 32) inpt.spatial_size = (24, 32)
# vfdev-5, Feature Request: let's store params as Transform attribute # vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users # This could be also helpful for users
...@@ -718,8 +718,8 @@ class TestRandomCrop: ...@@ -718,8 +718,8 @@ class TestRandomCrop:
def test__get_params(self, padding, pad_if_needed, size, mocker): def test__get_params(self, padding, pad_if_needed, size, mocker):
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=features.Image)
image.num_channels = 3 image.num_channels = 3
image.image_size = (24, 32) image.spatial_size = (24, 32)
h, w = image.image_size h, w = image.spatial_size
transform = transforms.RandomCrop(size, padding=padding, pad_if_needed=pad_if_needed) transform = transforms.RandomCrop(size, padding=padding, pad_if_needed=pad_if_needed)
params = transform._get_params(image) params = transform._get_params(image)
...@@ -771,19 +771,19 @@ class TestRandomCrop: ...@@ -771,19 +771,19 @@ class TestRandomCrop:
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=features.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.image_size = (32, 32) inpt.spatial_size = (32, 32)
expected = mocker.MagicMock(spec=features.Image) expected = mocker.MagicMock(spec=features.Image)
expected.num_channels = 3 expected.num_channels = 3
if isinstance(padding, int): if isinstance(padding, int):
expected.image_size = (inpt.image_size[0] + padding, inpt.image_size[1] + padding) expected.spatial_size = (inpt.spatial_size[0] + padding, inpt.spatial_size[1] + padding)
elif isinstance(padding, list): elif isinstance(padding, list):
expected.image_size = ( expected.spatial_size = (
inpt.image_size[0] + sum(padding[0::2]), inpt.spatial_size[0] + sum(padding[0::2]),
inpt.image_size[1] + sum(padding[1::2]), inpt.spatial_size[1] + sum(padding[1::2]),
) )
else: else:
expected.image_size = inpt.image_size expected.spatial_size = inpt.spatial_size
_ = mocker.patch("torchvision.prototype.transforms.functional.pad", return_value=expected) _ = mocker.patch("torchvision.prototype.transforms.functional.pad", return_value=expected)
fn_crop = mocker.patch("torchvision.prototype.transforms.functional.crop") fn_crop = mocker.patch("torchvision.prototype.transforms.functional.crop")
...@@ -859,7 +859,7 @@ class TestGaussianBlur: ...@@ -859,7 +859,7 @@ class TestGaussianBlur:
fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur") fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur")
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=features.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.image_size = (24, 32) inpt.spatial_size = (24, 32)
# vfdev-5, Feature Request: let's store params as Transform attribute # vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users # This could be also helpful for users
...@@ -910,11 +910,11 @@ class TestRandomPerspective: ...@@ -910,11 +910,11 @@ class TestRandomPerspective:
transform = transforms.RandomPerspective(dscale) transform = transforms.RandomPerspective(dscale)
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=features.Image)
image.num_channels = 3 image.num_channels = 3
image.image_size = (24, 32) image.spatial_size = (24, 32)
params = transform._get_params(image) params = transform._get_params(image)
h, w = image.image_size h, w = image.spatial_size
assert "perspective_coeffs" in params assert "perspective_coeffs" in params
assert len(params["perspective_coeffs"]) == 8 assert len(params["perspective_coeffs"]) == 8
...@@ -927,7 +927,7 @@ class TestRandomPerspective: ...@@ -927,7 +927,7 @@ class TestRandomPerspective:
fn = mocker.patch("torchvision.prototype.transforms.functional.perspective") fn = mocker.patch("torchvision.prototype.transforms.functional.perspective")
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=features.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.image_size = (24, 32) inpt.spatial_size = (24, 32)
# vfdev-5, Feature Request: let's store params as Transform attribute # vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users # This could be also helpful for users
# Otherwise, we can mock transform._get_params # Otherwise, we can mock transform._get_params
...@@ -971,11 +971,11 @@ class TestElasticTransform: ...@@ -971,11 +971,11 @@ class TestElasticTransform:
transform = transforms.ElasticTransform(alpha, sigma) transform = transforms.ElasticTransform(alpha, sigma)
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=features.Image)
image.num_channels = 3 image.num_channels = 3
image.image_size = (24, 32) image.spatial_size = (24, 32)
params = transform._get_params(image) params = transform._get_params(image)
h, w = image.image_size h, w = image.spatial_size
displacement = params["displacement"] displacement = params["displacement"]
assert displacement.shape == (1, h, w, 2) assert displacement.shape == (1, h, w, 2)
assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all() assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
...@@ -1001,7 +1001,7 @@ class TestElasticTransform: ...@@ -1001,7 +1001,7 @@ class TestElasticTransform:
fn = mocker.patch("torchvision.prototype.transforms.functional.elastic") fn = mocker.patch("torchvision.prototype.transforms.functional.elastic")
inpt = mocker.MagicMock(spec=features.Image) inpt = mocker.MagicMock(spec=features.Image)
inpt.num_channels = 3 inpt.num_channels = 3
inpt.image_size = (24, 32) inpt.spatial_size = (24, 32)
# Let's mock transform._get_params to control the output: # Let's mock transform._get_params to control the output:
transform._get_params = mocker.MagicMock() transform._get_params = mocker.MagicMock()
...@@ -1030,7 +1030,7 @@ class TestRandomErasing: ...@@ -1030,7 +1030,7 @@ class TestRandomErasing:
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=features.Image)
image.num_channels = 3 image.num_channels = 3
image.image_size = (24, 32) image.spatial_size = (24, 32)
transform = transforms.RandomErasing(value=[1, 2, 3, 4]) transform = transforms.RandomErasing(value=[1, 2, 3, 4])
...@@ -1041,7 +1041,7 @@ class TestRandomErasing: ...@@ -1041,7 +1041,7 @@ class TestRandomErasing:
def test__get_params(self, value, mocker): def test__get_params(self, value, mocker):
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=features.Image)
image.num_channels = 3 image.num_channels = 3
image.image_size = (24, 32) image.spatial_size = (24, 32)
transform = transforms.RandomErasing(value=value) transform = transforms.RandomErasing(value=value)
params = transform._get_params(image) params = transform._get_params(image)
...@@ -1057,8 +1057,8 @@ class TestRandomErasing: ...@@ -1057,8 +1057,8 @@ class TestRandomErasing:
elif isinstance(value, (list, tuple)): elif isinstance(value, (list, tuple)):
assert v.shape == (image.num_channels, 1, 1) assert v.shape == (image.num_channels, 1, 1)
assert 0 <= i <= image.image_size[0] - h assert 0 <= i <= image.spatial_size[0] - h
assert 0 <= j <= image.image_size[1] - w assert 0 <= j <= image.spatial_size[1] - w
@pytest.mark.parametrize("p", [0, 1]) @pytest.mark.parametrize("p", [0, 1])
def test__transform(self, mocker, p): def test__transform(self, mocker, p):
...@@ -1222,11 +1222,11 @@ class TestRandomIoUCrop: ...@@ -1222,11 +1222,11 @@ class TestRandomIoUCrop:
def test__get_params(self, device, options, mocker): def test__get_params(self, device, options, mocker):
image = mocker.MagicMock(spec=features.Image) image = mocker.MagicMock(spec=features.Image)
image.num_channels = 3 image.num_channels = 3
image.image_size = (24, 32) image.spatial_size = (24, 32)
bboxes = features.BoundingBox( bboxes = features.BoundingBox(
torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]), torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
format="XYXY", format="XYXY",
image_size=image.image_size, spatial_size=image.spatial_size,
device=device, device=device,
) )
sample = [image, bboxes] sample = [image, bboxes]
...@@ -1245,8 +1245,8 @@ class TestRandomIoUCrop: ...@@ -1245,8 +1245,8 @@ class TestRandomIoUCrop:
assert len(params["is_within_crop_area"]) > 0 assert len(params["is_within_crop_area"]) > 0
assert params["is_within_crop_area"].dtype == torch.bool assert params["is_within_crop_area"].dtype == torch.bool
orig_h = image.image_size[0] orig_h = image.spatial_size[0]
orig_w = image.image_size[1] orig_w = image.spatial_size[1]
assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h) assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h)
assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w) assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w)
...@@ -1261,7 +1261,7 @@ class TestRandomIoUCrop: ...@@ -1261,7 +1261,7 @@ class TestRandomIoUCrop:
def test__transform_empty_params(self, mocker): def test__transform_empty_params(self, mocker):
transform = transforms.RandomIoUCrop(sampler_options=[2.0]) transform = transforms.RandomIoUCrop(sampler_options=[2.0])
image = features.Image(torch.rand(1, 3, 4, 4)) image = features.Image(torch.rand(1, 3, 4, 4))
bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4)) bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4))
label = features.Label(torch.tensor([1])) label = features.Label(torch.tensor([1]))
sample = [image, bboxes, label] sample = [image, bboxes, label]
# Let's mock transform._get_params to control the output: # Let's mock transform._get_params to control the output:
...@@ -1281,7 +1281,7 @@ class TestRandomIoUCrop: ...@@ -1281,7 +1281,7 @@ class TestRandomIoUCrop:
transform = transforms.RandomIoUCrop() transform = transforms.RandomIoUCrop()
image = features.Image(torch.rand(3, 32, 24)) image = features.Image(torch.rand(3, 32, 24))
bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,)) bboxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), extra_dims=(6,))
label = features.Label(torch.randint(0, 10, size=(6,))) label = features.Label(torch.randint(0, 10, size=(6,)))
ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1)) ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
masks = make_detection_mask((32, 24), num_objects=6) masks = make_detection_mask((32, 24), num_objects=6)
...@@ -1329,12 +1329,12 @@ class TestRandomIoUCrop: ...@@ -1329,12 +1329,12 @@ class TestRandomIoUCrop:
class TestScaleJitter: class TestScaleJitter:
def test__get_params(self, mocker): def test__get_params(self, mocker):
image_size = (24, 32) spatial_size = (24, 32)
target_size = (16, 12) target_size = (16, 12)
scale_range = (0.5, 1.5) scale_range = (0.5, 1.5)
transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range) transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)
sample = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=image_size) sample = mocker.MagicMock(spec=features.Image, num_channels=3, spatial_size=spatial_size)
n_samples = 5 n_samples = 5
for _ in range(n_samples): for _ in range(n_samples):
...@@ -1347,11 +1347,11 @@ class TestScaleJitter: ...@@ -1347,11 +1347,11 @@ class TestScaleJitter:
assert isinstance(size, tuple) and len(size) == 2 assert isinstance(size, tuple) and len(size) == 2
height, width = size height, width = size
r_min = min(target_size[1] / image_size[0], target_size[0] / image_size[1]) * scale_range[0] r_min = min(target_size[1] / spatial_size[0], target_size[0] / spatial_size[1]) * scale_range[0]
r_max = min(target_size[1] / image_size[0], target_size[0] / image_size[1]) * scale_range[1] r_max = min(target_size[1] / spatial_size[0], target_size[0] / spatial_size[1]) * scale_range[1]
assert int(image_size[0] * r_min) <= height <= int(image_size[0] * r_max) assert int(spatial_size[0] * r_min) <= height <= int(spatial_size[0] * r_max)
assert int(image_size[1] * r_min) <= width <= int(image_size[1] * r_max) assert int(spatial_size[1] * r_min) <= width <= int(spatial_size[1] * r_max)
def test__transform(self, mocker): def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock() interpolation_sentinel = mocker.MagicMock()
...@@ -1379,13 +1379,13 @@ class TestScaleJitter: ...@@ -1379,13 +1379,13 @@ class TestScaleJitter:
class TestRandomShortestSize: class TestRandomShortestSize:
def test__get_params(self, mocker): def test__get_params(self, mocker):
image_size = (3, 10) spatial_size = (3, 10)
min_size = [5, 9] min_size = [5, 9]
max_size = 20 max_size = 20
transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size) transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)
sample = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=image_size) sample = mocker.MagicMock(spec=features.Image, num_channels=3, spatial_size=spatial_size)
params = transform._get_params(sample) params = transform._get_params(sample)
assert "size" in params assert "size" in params
...@@ -1504,7 +1504,7 @@ class TestSimpleCopyPaste: ...@@ -1504,7 +1504,7 @@ class TestSimpleCopyPaste:
labels = torch.nn.functional.one_hot(labels, num_classes=5) labels = torch.nn.functional.one_hot(labels, num_classes=5)
target = { target = {
"boxes": features.BoundingBox( "boxes": features.BoundingBox(
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32) torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32)
), ),
"masks": features.Mask(masks), "masks": features.Mask(masks),
"labels": label_type(labels), "labels": label_type(labels),
...@@ -1519,7 +1519,7 @@ class TestSimpleCopyPaste: ...@@ -1519,7 +1519,7 @@ class TestSimpleCopyPaste:
paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5) paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
paste_target = { paste_target = {
"boxes": features.BoundingBox( "boxes": features.BoundingBox(
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32) torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32)
), ),
"masks": features.Mask(paste_masks), "masks": features.Mask(paste_masks),
"labels": label_type(paste_labels), "labels": label_type(paste_labels),
...@@ -1550,14 +1550,14 @@ class TestFixedSizeCrop: ...@@ -1550,14 +1550,14 @@ class TestFixedSizeCrop:
def test__get_params(self, mocker): def test__get_params(self, mocker):
crop_size = (7, 7) crop_size = (7, 7)
batch_shape = (10,) batch_shape = (10,)
image_size = (11, 5) spatial_size = (11, 5)
transform = transforms.FixedSizeCrop(size=crop_size) transform = transforms.FixedSizeCrop(size=crop_size)
sample = dict( sample = dict(
image=make_image(size=image_size, color_space=features.ColorSpace.RGB), image=make_image(size=spatial_size, color_space=features.ColorSpace.RGB),
bounding_boxes=make_bounding_box( bounding_boxes=make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=batch_shape format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape
), ),
) )
params = transform._get_params(sample) params = transform._get_params(sample)
...@@ -1638,7 +1638,7 @@ class TestFixedSizeCrop: ...@@ -1638,7 +1638,7 @@ class TestFixedSizeCrop:
def test__transform_culling(self, mocker): def test__transform_culling(self, mocker):
batch_size = 10 batch_size = 10
image_size = (10, 10) spatial_size = (10, 10)
is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool) is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
mocker.patch( mocker.patch(
...@@ -1647,17 +1647,17 @@ class TestFixedSizeCrop: ...@@ -1647,17 +1647,17 @@ class TestFixedSizeCrop:
needs_crop=True, needs_crop=True,
top=0, top=0,
left=0, left=0,
height=image_size[0], height=spatial_size[0],
width=image_size[1], width=spatial_size[1],
is_valid=is_valid, is_valid=is_valid,
needs_pad=False, needs_pad=False,
), ),
) )
bounding_boxes = make_bounding_box( bounding_boxes = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,) format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,)
) )
masks = make_detection_mask(size=image_size, extra_dims=(batch_size,)) masks = make_detection_mask(size=spatial_size, extra_dims=(batch_size,))
labels = make_label(extra_dims=(batch_size,)) labels = make_label(extra_dims=(batch_size,))
transform = transforms.FixedSizeCrop((-1, -1)) transform = transforms.FixedSizeCrop((-1, -1))
...@@ -1678,7 +1678,7 @@ class TestFixedSizeCrop: ...@@ -1678,7 +1678,7 @@ class TestFixedSizeCrop:
def test__transform_bounding_box_clamping(self, mocker): def test__transform_bounding_box_clamping(self, mocker):
batch_size = 3 batch_size = 3
image_size = (10, 10) spatial_size = (10, 10)
mocker.patch( mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
...@@ -1686,15 +1686,15 @@ class TestFixedSizeCrop: ...@@ -1686,15 +1686,15 @@ class TestFixedSizeCrop:
needs_crop=True, needs_crop=True,
top=0, top=0,
left=0, left=0,
height=image_size[0], height=spatial_size[0],
width=image_size[1], width=spatial_size[1],
is_valid=torch.full((batch_size,), fill_value=True), is_valid=torch.full((batch_size,), fill_value=True),
needs_pad=False, needs_pad=False,
), ),
) )
bounding_box = make_bounding_box( bounding_box = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,) format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=(batch_size,)
) )
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box") mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
......
...@@ -24,7 +24,7 @@ from torchvision import transforms as legacy_transforms ...@@ -24,7 +24,7 @@ from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms from torchvision.prototype import features, transforms as prototype_transforms
from torchvision.prototype.transforms import functional as F from torchvision.prototype.transforms import functional as F
from torchvision.prototype.transforms._utils import query_chw from torchvision.prototype.transforms._utils import query_spatial_size
from torchvision.prototype.transforms.functional import to_image_pil from torchvision.prototype.transforms.functional import to_image_pil
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)]) DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
...@@ -871,7 +871,7 @@ class TestRefDetTransforms: ...@@ -871,7 +871,7 @@ class TestRefDetTransforms:
pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB)) pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB))
target = { target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80), "labels": make_label(extra_dims=(num_objects,), categories=80),
} }
if with_mask: if with_mask:
...@@ -881,7 +881,7 @@ class TestRefDetTransforms: ...@@ -881,7 +881,7 @@ class TestRefDetTransforms:
tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB)) tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB))
target = { target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80), "labels": make_label(extra_dims=(num_objects,), categories=80),
} }
if with_mask: if with_mask:
...@@ -891,7 +891,7 @@ class TestRefDetTransforms: ...@@ -891,7 +891,7 @@ class TestRefDetTransforms:
feature_image = make_image(size=size, color_space=features.ColorSpace.RGB) feature_image = make_image(size=size, color_space=features.ColorSpace.RGB)
target = { target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80), "labels": make_label(extra_dims=(num_objects,), categories=80),
} }
if with_mask: if with_mask:
...@@ -949,7 +949,7 @@ class PadIfSmaller(prototype_transforms.Transform): ...@@ -949,7 +949,7 @@ class PadIfSmaller(prototype_transforms.Transform):
self.fill = prototype_transforms._geometry._setup_fill_arg(fill) self.fill = prototype_transforms._geometry._setup_fill_arg(fill)
def _get_params(self, sample): def _get_params(self, sample):
_, height, width = query_chw(sample) height, width = query_spatial_size(sample)
padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)] padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
needs_padding = any(padding) needs_padding = any(padding)
return dict(padding=padding, needs_padding=needs_padding) return dict(padding=padding, needs_padding=needs_padding)
......
...@@ -224,11 +224,14 @@ class TestDispatchers: ...@@ -224,11 +224,14 @@ class TestDispatchers:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dispatcher", "dispatcher",
[ [
F.clamp_bounding_box,
F.convert_color_space, F.convert_color_space,
F.convert_image_dtype, F.convert_image_dtype,
F.get_dimensions, F.get_dimensions,
F.get_image_num_channels, F.get_image_num_channels,
F.get_image_size, F.get_image_size,
F.get_num_channels,
F.get_num_frames,
F.get_spatial_size, F.get_spatial_size,
F.rgb_to_grayscale, F.rgb_to_grayscale,
], ],
...@@ -333,16 +336,16 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_): ...@@ -333,16 +336,16 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_affine_bounding_box_on_fixed_input(device): def test_correctness_affine_bounding_box_on_fixed_input(device):
# Check transformation against known expected output # Check transformation against known expected output
image_size = (64, 64) spatial_size = (64, 64)
# xyxy format # xyxy format
in_boxes = [ in_boxes = [
[20, 25, 35, 45], [20, 25, 35, 45],
[50, 5, 70, 22], [50, 5, 70, 22],
[image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10], [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
[1, 1, 5, 5], [1, 1, 5, 5],
] ]
in_boxes = features.BoundingBox( in_boxes = features.BoundingBox(
in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64, device=device in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, dtype=torch.float64, device=device
) )
# Tested parameters # Tested parameters
angle = 63 angle = 63
...@@ -355,9 +358,9 @@ def test_correctness_affine_bounding_box_on_fixed_input(device): ...@@ -355,9 +358,9 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
# from albumentations.augmentations.geometric.functional import normalize_bbox, denormalize_bbox # from albumentations.augmentations.geometric.functional import normalize_bbox, denormalize_bbox
# expected_bboxes = [] # expected_bboxes = []
# for in_box in in_boxes: # for in_box in in_boxes:
# n_in_box = normalize_bbox(in_box, *image_size) # n_in_box = normalize_bbox(in_box, *spatial_size)
# n_out_box = bbox_shift_scale_rotate(n_in_box, -angle, scale, dx, dy, *image_size) # n_out_box = bbox_shift_scale_rotate(n_in_box, -angle, scale, dx, dy, *spatial_size)
# out_box = denormalize_bbox(n_out_box, *image_size) # out_box = denormalize_bbox(n_out_box, *spatial_size)
# expected_bboxes.append(out_box) # expected_bboxes.append(out_box)
expected_bboxes = [ expected_bboxes = [
(24.522435977922218, 34.375689508290854, 46.443125279998114, 54.3516575015695), (24.522435977922218, 34.375689508290854, 46.443125279998114, 54.3516575015695),
...@@ -369,9 +372,9 @@ def test_correctness_affine_bounding_box_on_fixed_input(device): ...@@ -369,9 +372,9 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
output_boxes = F.affine_bounding_box( output_boxes = F.affine_bounding_box(
in_boxes, in_boxes,
in_boxes.format, in_boxes.format,
in_boxes.image_size, in_boxes.spatial_size,
angle, angle,
(dx * image_size[1], dy * image_size[0]), (dx * spatial_size[1], dy * spatial_size[0]),
scale, scale,
shear=(0, 0), shear=(0, 0),
) )
...@@ -406,7 +409,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -406,7 +409,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
affine_matrix = affine_matrix[:2, :] affine_matrix = affine_matrix[:2, :]
height, width = bbox.image_size height, width = bbox.spatial_size
bbox_xyxy = convert_format_bounding_box( bbox_xyxy = convert_format_bounding_box(
bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
) )
...@@ -444,7 +447,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -444,7 +447,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
out_bbox = features.BoundingBox( out_bbox = features.BoundingBox(
out_bbox, out_bbox,
format=features.BoundingBoxFormat.XYXY, format=features.BoundingBoxFormat.XYXY,
image_size=(height, width), spatial_size=(height, width),
dtype=bbox.dtype, dtype=bbox.dtype,
device=bbox.device, device=bbox.device,
) )
...@@ -455,16 +458,16 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -455,16 +458,16 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
(height, width), (height, width),
) )
image_size = (32, 38) spatial_size = (32, 38)
for bboxes in make_bounding_boxes(image_size=image_size, extra_dims=((4,),)): for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)):
bboxes_format = bboxes.format bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size bboxes_spatial_size = bboxes.spatial_size
output_bboxes, output_image_size = F.rotate_bounding_box( output_bboxes, output_spatial_size = F.rotate_bounding_box(
bboxes, bboxes,
bboxes_format, bboxes_format,
image_size=bboxes_image_size, spatial_size=bboxes_spatial_size,
angle=angle, angle=angle,
expand=expand, expand=expand,
center=center, center=center,
...@@ -472,38 +475,38 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -472,38 +475,38 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
center_ = center center_ = center
if center_ is None: if center_ is None:
center_ = [s * 0.5 for s in bboxes_image_size[::-1]] center_ = [s * 0.5 for s in bboxes_spatial_size[::-1]]
if bboxes.ndim < 2: if bboxes.ndim < 2:
bboxes = [bboxes] bboxes = [bboxes]
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
expected_bbox, expected_image_size = _compute_expected_bbox(bbox, -angle, expand, center_) expected_bbox, expected_spatial_size = _compute_expected_bbox(bbox, -angle, expand, center_)
expected_bboxes.append(expected_bbox) expected_bboxes.append(expected_bbox)
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes) expected_bboxes = torch.stack(expected_bboxes)
else: else:
expected_bboxes = expected_bboxes[0] expected_bboxes = expected_bboxes[0]
torch.testing.assert_close(output_bboxes, expected_bboxes, atol=1, rtol=0) torch.testing.assert_close(output_bboxes, expected_bboxes, atol=1, rtol=0)
torch.testing.assert_close(output_image_size, expected_image_size, atol=1, rtol=0) torch.testing.assert_close(output_spatial_size, expected_spatial_size, atol=1, rtol=0)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2 @pytest.mark.parametrize("expand", [False]) # expand=True does not match D2
def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
# Check transformation against known expected output # Check transformation against known expected output
image_size = (64, 64) spatial_size = (64, 64)
# xyxy format # xyxy format
in_boxes = [ in_boxes = [
[1, 1, 5, 5], [1, 1, 5, 5],
[1, image_size[0] - 6, 5, image_size[0] - 2], [1, spatial_size[0] - 6, 5, spatial_size[0] - 2],
[image_size[1] - 6, image_size[0] - 6, image_size[1] - 2, image_size[0] - 2], [spatial_size[1] - 6, spatial_size[0] - 6, spatial_size[1] - 2, spatial_size[0] - 2],
[image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10], [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10],
] ]
in_boxes = features.BoundingBox( in_boxes = features.BoundingBox(
in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64, device=device in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, dtype=torch.float64, device=device
) )
# Tested parameters # Tested parameters
angle = 45 angle = 45
...@@ -535,7 +538,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): ...@@ -535,7 +538,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
output_boxes, _ = F.rotate_bounding_box( output_boxes, _ = F.rotate_bounding_box(
in_boxes, in_boxes,
in_boxes.format, in_boxes.format,
in_boxes.image_size, in_boxes.spatial_size,
angle, angle,
expand=expand, expand=expand,
center=center, center=center,
...@@ -593,11 +596,11 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width, ...@@ -593,11 +596,11 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
[50.0, 5.0, 70.0, 22.0], [50.0, 5.0, 70.0, 22.0],
[45.0, 46.0, 56.0, 62.0], [45.0, 46.0, 56.0, 62.0],
] ]
in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=size, device=device) in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=size, device=device)
if format != features.BoundingBoxFormat.XYXY: if format != features.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format) in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format)
output_boxes, output_image_size = F.crop_bounding_box( output_boxes, output_spatial_size = F.crop_bounding_box(
in_boxes, in_boxes,
format, format,
top, top,
...@@ -610,7 +613,7 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width, ...@@ -610,7 +613,7 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY) output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
torch.testing.assert_close(output_image_size, size) torch.testing.assert_close(output_spatial_size, size)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
...@@ -658,7 +661,7 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height ...@@ -658,7 +661,7 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
bbox[3] = (bbox[3] - top_) * size_[0] / height_ bbox[3] = (bbox[3] - top_) * size_[0] / height_
return bbox return bbox
image_size = (100, 100) spatial_size = (100, 100)
# xyxy format # xyxy format
in_boxes = [ in_boxes = [
[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0],
...@@ -670,18 +673,18 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height ...@@ -670,18 +673,18 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
expected_bboxes = torch.tensor(expected_bboxes, device=device) expected_bboxes = torch.tensor(expected_bboxes, device=device)
in_boxes = features.BoundingBox( in_boxes = features.BoundingBox(
in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, device=device in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device
) )
if format != features.BoundingBoxFormat.XYXY: if format != features.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format) in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format)
output_boxes, output_image_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size) output_boxes, output_spatial_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
if format != features.BoundingBoxFormat.XYXY: if format != features.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY) output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes, expected_bboxes) torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_image_size, size) torch.testing.assert_close(output_spatial_size, size)
def _parse_padding(padding): def _parse_padding(padding):
...@@ -718,28 +721,28 @@ def test_correctness_pad_bounding_box(device, padding): ...@@ -718,28 +721,28 @@ def test_correctness_pad_bounding_box(device, padding):
bbox = bbox.to(bbox_dtype) bbox = bbox.to(bbox_dtype)
return bbox return bbox
def _compute_expected_image_size(bbox, padding_): def _compute_expected_spatial_size(bbox, padding_):
pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_) pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
height, width = bbox.image_size height, width = bbox.spatial_size
return height + pad_up + pad_down, width + pad_left + pad_right return height + pad_up + pad_down, width + pad_left + pad_right
for bboxes in make_bounding_boxes(): for bboxes in make_bounding_boxes():
bboxes = bboxes.to(device) bboxes = bboxes.to(device)
bboxes_format = bboxes.format bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size bboxes_spatial_size = bboxes.spatial_size
output_boxes, output_image_size = F.pad_bounding_box( output_boxes, output_spatial_size = F.pad_bounding_box(
bboxes, format=bboxes_format, image_size=bboxes_image_size, padding=padding bboxes, format=bboxes_format, spatial_size=bboxes_spatial_size, padding=padding
) )
torch.testing.assert_close(output_image_size, _compute_expected_image_size(bboxes, padding)) torch.testing.assert_close(output_spatial_size, _compute_expected_spatial_size(bboxes, padding))
if bboxes.ndim < 2 or bboxes.shape[0] == 0: if bboxes.ndim < 2 or bboxes.shape[0] == 0:
bboxes = [bboxes] bboxes = [bboxes]
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
expected_bboxes.append(_compute_expected_bbox(bbox, padding)) expected_bboxes.append(_compute_expected_bbox(bbox, padding))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
...@@ -807,7 +810,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -807,7 +810,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
out_bbox = features.BoundingBox( out_bbox = features.BoundingBox(
np.array(out_bbox), np.array(out_bbox),
format=features.BoundingBoxFormat.XYXY, format=features.BoundingBoxFormat.XYXY,
image_size=bbox.image_size, spatial_size=bbox.spatial_size,
dtype=bbox.dtype, dtype=bbox.dtype,
device=bbox.device, device=bbox.device,
) )
...@@ -815,15 +818,15 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -815,15 +818,15 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
) )
image_size = (32, 38) spatial_size = (32, 38)
pcoeffs = _get_perspective_coeffs(startpoints, endpoints) pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints) inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)
for bboxes in make_bounding_boxes(image_size=image_size, extra_dims=((4,),)): for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)):
bboxes = bboxes.to(device) bboxes = bboxes.to(device)
bboxes_format = bboxes.format bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size bboxes_spatial_size = bboxes.spatial_size
output_bboxes = F.perspective_bounding_box( output_bboxes = F.perspective_bounding_box(
bboxes, bboxes,
...@@ -836,7 +839,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -836,7 +839,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs)) expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes) expected_bboxes = torch.stack(expected_bboxes)
...@@ -853,14 +856,14 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints): ...@@ -853,14 +856,14 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
def test_correctness_center_crop_bounding_box(device, output_size): def test_correctness_center_crop_bounding_box(device, output_size):
def _compute_expected_bbox(bbox, output_size_): def _compute_expected_bbox(bbox, output_size_):
format_ = bbox.format format_ = bbox.format
image_size_ = bbox.image_size spatial_size_ = bbox.spatial_size
bbox = convert_format_bounding_box(bbox, format_, features.BoundingBoxFormat.XYWH) bbox = convert_format_bounding_box(bbox, format_, features.BoundingBoxFormat.XYWH)
if len(output_size_) == 1: if len(output_size_) == 1:
output_size_.append(output_size_[-1]) output_size_.append(output_size_[-1])
cy = int(round((image_size_[0] - output_size_[0]) * 0.5)) cy = int(round((spatial_size_[0] - output_size_[0]) * 0.5))
cx = int(round((image_size_[1] - output_size_[1]) * 0.5)) cx = int(round((spatial_size_[1] - output_size_[1]) * 0.5))
out_bbox = [ out_bbox = [
bbox[0].item() - cx, bbox[0].item() - cx,
bbox[1].item() - cy, bbox[1].item() - cy,
...@@ -870,7 +873,7 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -870,7 +873,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
out_bbox = features.BoundingBox( out_bbox = features.BoundingBox(
out_bbox, out_bbox,
format=features.BoundingBoxFormat.XYWH, format=features.BoundingBoxFormat.XYWH,
image_size=output_size_, spatial_size=output_size_,
dtype=bbox.dtype, dtype=bbox.dtype,
device=bbox.device, device=bbox.device,
) )
...@@ -879,10 +882,10 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -879,10 +882,10 @@ def test_correctness_center_crop_bounding_box(device, output_size):
for bboxes in make_bounding_boxes(extra_dims=((4,),)): for bboxes in make_bounding_boxes(extra_dims=((4,),)):
bboxes = bboxes.to(device) bboxes = bboxes.to(device)
bboxes_format = bboxes.format bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size bboxes_spatial_size = bboxes.spatial_size
output_boxes, output_image_size = F.center_crop_bounding_box( output_boxes, output_spatial_size = F.center_crop_bounding_box(
bboxes, bboxes_format, bboxes_image_size, output_size bboxes, bboxes_format, bboxes_spatial_size, output_size
) )
if bboxes.ndim < 2: if bboxes.ndim < 2:
...@@ -890,7 +893,7 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -890,7 +893,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
expected_bboxes.append(_compute_expected_bbox(bbox, output_size)) expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
...@@ -898,7 +901,7 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -898,7 +901,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
else: else:
expected_bboxes = expected_bboxes[0] expected_bboxes = expected_bboxes[0]
torch.testing.assert_close(output_boxes, expected_bboxes) torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_image_size, output_size) torch.testing.assert_close(output_spatial_size, output_size)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
...@@ -926,11 +929,11 @@ def test_correctness_center_crop_mask(device, output_size): ...@@ -926,11 +929,11 @@ def test_correctness_center_crop_mask(device, output_size):
# Copied from test/test_functional_tensor.py # Copied from test/test_functional_tensor.py
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("image_size", ("small", "large")) @pytest.mark.parametrize("spatial_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)]) @pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]) @pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, sigma): def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize, sigma):
fn = F.gaussian_blur_image_tensor fn = F.gaussian_blur_image_tensor
# true_cv2_results = { # true_cv2_results = {
...@@ -950,7 +953,7 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s ...@@ -950,7 +953,7 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt") p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
true_cv2_results = torch.load(p) true_cv2_results = torch.load(p)
if image_size == "small": if spatial_size == "small":
tensor = ( tensor = (
torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device) torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
) )
......
...@@ -11,8 +11,8 @@ from torchvision.prototype.transforms.functional import to_image_pil ...@@ -11,8 +11,8 @@ from torchvision.prototype.transforms.functional import to_image_pil
IMAGE = make_image(color_space=features.ColorSpace.RGB) IMAGE = make_image(color_space=features.ColorSpace.RGB)
BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size) BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size)
MASK = make_detection_mask(size=IMAGE.image_size) MASK = make_detection_mask(size=IMAGE.spatial_size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -110,7 +110,9 @@ class Caltech101(Dataset): ...@@ -110,7 +110,9 @@ class Caltech101(Dataset):
image=image, image=image,
ann_path=ann_path, ann_path=ann_path,
bounding_box=BoundingBox( bounding_box=BoundingBox(
ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", image_size=image.image_size ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]],
format="xyxy",
spatial_size=image.spatial_size,
), ),
contour=_Feature(ann["obj_contour"].T), contour=_Feature(ann["obj_contour"].T),
) )
......
...@@ -144,7 +144,7 @@ class CelebA(Dataset): ...@@ -144,7 +144,7 @@ class CelebA(Dataset):
bounding_box=BoundingBox( bounding_box=BoundingBox(
[int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")], [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")],
format="xywh", format="xywh",
image_size=image.image_size, spatial_size=image.spatial_size,
), ),
landmarks={ landmarks={
landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"]))) landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"])))
......
...@@ -97,25 +97,29 @@ class Coco(Dataset): ...@@ -97,25 +97,29 @@ class Coco(Dataset):
) )
return [images, meta] return [images, meta]
def _segmentation_to_mask(self, segmentation: Any, *, is_crowd: bool, image_size: Tuple[int, int]) -> torch.Tensor: def _segmentation_to_mask(
self, segmentation: Any, *, is_crowd: bool, spatial_size: Tuple[int, int]
) -> torch.Tensor:
from pycocotools import mask from pycocotools import mask
if is_crowd: if is_crowd:
segmentation = mask.frPyObjects(segmentation, *image_size) segmentation = mask.frPyObjects(segmentation, *spatial_size)
else: else:
segmentation = mask.merge(mask.frPyObjects(segmentation, *image_size)) segmentation = mask.merge(mask.frPyObjects(segmentation, *spatial_size))
return torch.from_numpy(mask.decode(segmentation)).to(torch.bool) return torch.from_numpy(mask.decode(segmentation)).to(torch.bool)
def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]: def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
image_size = (image_meta["height"], image_meta["width"]) spatial_size = (image_meta["height"], image_meta["width"])
labels = [ann["category_id"] for ann in anns] labels = [ann["category_id"] for ann in anns]
return dict( return dict(
# TODO: create a segmentation feature # TODO: create a segmentation feature
segmentations=_Feature( segmentations=_Feature(
torch.stack( torch.stack(
[ [
self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size) self._segmentation_to_mask(
ann["segmentation"], is_crowd=ann["iscrowd"], spatial_size=spatial_size
)
for ann in anns for ann in anns
] ]
) )
...@@ -125,7 +129,7 @@ class Coco(Dataset): ...@@ -125,7 +129,7 @@ class Coco(Dataset):
bounding_boxes=BoundingBox( bounding_boxes=BoundingBox(
[ann["bbox"] for ann in anns], [ann["bbox"] for ann in anns],
format="xywh", format="xywh",
image_size=image_size, spatial_size=spatial_size,
), ),
labels=Label(labels, categories=self._categories), labels=Label(labels, categories=self._categories),
super_categories=[self._category_to_super_category[self._categories[label]] for label in labels], super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
......
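Note that `spatial_size` unpacks as `(height, width)`, which matches the `(h, w)` argument order pycocotools expects in `frPyObjects`. A small standalone sketch of the non-crowd decode path shown above, assuming pycocotools is installed; the polygon coordinates are made up purely for illustration:

```python
import torch
from pycocotools import mask as mask_utils

spatial_size = (480, 640)  # (height, width)

# A COCO-style polygon segmentation: a list of flat [x1, y1, x2, y2, ...] rings.
segmentation = [[100.0, 100.0, 200.0, 100.0, 200.0, 200.0, 100.0, 200.0]]

rles = mask_utils.frPyObjects(segmentation, *spatial_size)  # one RLE per polygon ring
rle = mask_utils.merge(rles)                                 # union of all rings
mask = torch.from_numpy(mask_utils.decode(rle)).to(torch.bool)
print(mask.shape)  # torch.Size([480, 640])
```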
...@@ -130,13 +130,13 @@ class CUB200(Dataset): ...@@ -130,13 +130,13 @@ class CUB200(Dataset):
return path.with_suffix(".jpg").name return path.with_suffix(".jpg").name
def _2011_prepare_ann( def _2011_prepare_ann(
self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], image_size: Tuple[int, int] self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], spatial_size: Tuple[int, int]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
_, (bounding_box_data, segmentation_data) = data _, (bounding_box_data, segmentation_data) = data
segmentation_path, segmentation_buffer = segmentation_data segmentation_path, segmentation_buffer = segmentation_data
return dict( return dict(
bounding_box=BoundingBox( bounding_box=BoundingBox(
[float(part) for part in bounding_box_data[1:]], format="xywh", image_size=image_size [float(part) for part in bounding_box_data[1:]], format="xywh", spatial_size=spatial_size
), ),
segmentation_path=segmentation_path, segmentation_path=segmentation_path,
segmentation=EncodedImage.from_file(segmentation_buffer), segmentation=EncodedImage.from_file(segmentation_buffer),
...@@ -149,7 +149,9 @@ class CUB200(Dataset): ...@@ -149,7 +149,9 @@ class CUB200(Dataset):
path = pathlib.Path(data[0]) path = pathlib.Path(data[0])
return path.with_suffix(".jpg").name, data return path.with_suffix(".jpg").name, data
def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], image_size: Tuple[int, int]) -> Dict[str, Any]: def _2010_prepare_ann(
self, data: Tuple[str, Tuple[str, BinaryIO]], spatial_size: Tuple[int, int]
) -> Dict[str, Any]:
_, (path, buffer) = data _, (path, buffer) = data
content = read_mat(buffer) content = read_mat(buffer)
return dict( return dict(
...@@ -157,7 +159,7 @@ class CUB200(Dataset): ...@@ -157,7 +159,7 @@ class CUB200(Dataset):
bounding_box=BoundingBox( bounding_box=BoundingBox(
[int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")], [int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")],
format="xyxy", format="xyxy",
image_size=image_size, spatial_size=spatial_size,
), ),
segmentation=_Feature(content["seg"]), segmentation=_Feature(content["seg"]),
) )
...@@ -175,7 +177,7 @@ class CUB200(Dataset): ...@@ -175,7 +177,7 @@ class CUB200(Dataset):
image = EncodedImage.from_file(buffer) image = EncodedImage.from_file(buffer)
return dict( return dict(
prepare_ann_fn(anns_data, image.image_size), prepare_ann_fn(anns_data, image.spatial_size),
image=image, image=image,
label=Label( label=Label(
int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]) - 1, int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]) - 1,
......
...@@ -78,7 +78,7 @@ class GTSRB(Dataset): ...@@ -78,7 +78,7 @@ class GTSRB(Dataset):
bounding_box = BoundingBox( bounding_box = BoundingBox(
[int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")], [int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")],
format="xyxy", format="xyxy",
image_size=(int(csv_info["Height"]), int(csv_info["Width"])), spatial_size=(int(csv_info["Height"]), int(csv_info["Width"])),
) )
return { return {
......
...@@ -89,7 +89,7 @@ class StanfordCars(Dataset): ...@@ -89,7 +89,7 @@ class StanfordCars(Dataset):
path=path, path=path,
image=image, image=image,
label=Label(target[4] - 1, categories=self._categories), label=Label(target[4] - 1, categories=self._categories),
bounding_box=BoundingBox(target[:4], format="xyxy", image_size=image.image_size), bounding_box=BoundingBox(target[:4], format="xyxy", spatial_size=image.spatial_size),
) )
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
......
...@@ -108,7 +108,7 @@ class VOC(Dataset): ...@@ -108,7 +108,7 @@ class VOC(Dataset):
for instance in instances for instance in instances
], ],
format="xyxy", format="xyxy",
image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))), spatial_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))),
), ),
labels=Label( labels=Label(
[self._categories.index(instance["name"]) for instance in instances], categories=self._categories [self._categories.index(instance["name"]) for instance in instances], categories=self._categories
......
...@@ -17,13 +17,13 @@ class BoundingBoxFormat(StrEnum): ...@@ -17,13 +17,13 @@ class BoundingBoxFormat(StrEnum):
class BoundingBox(_Feature): class BoundingBox(_Feature):
format: BoundingBoxFormat format: BoundingBoxFormat
image_size: Tuple[int, int] spatial_size: Tuple[int, int]
@classmethod @classmethod
def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, image_size: Tuple[int, int]) -> BoundingBox: def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox:
bounding_box = tensor.as_subclass(cls) bounding_box = tensor.as_subclass(cls)
bounding_box.format = format bounding_box.format = format
bounding_box.image_size = image_size bounding_box.spatial_size = spatial_size
return bounding_box return bounding_box
def __new__( def __new__(
...@@ -31,7 +31,7 @@ class BoundingBox(_Feature): ...@@ -31,7 +31,7 @@ class BoundingBox(_Feature):
data: Any, data: Any,
*, *,
format: Union[BoundingBoxFormat, str], format: Union[BoundingBoxFormat, str],
image_size: Tuple[int, int], spatial_size: Tuple[int, int],
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None, device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False, requires_grad: bool = False,
...@@ -41,7 +41,7 @@ class BoundingBox(_Feature): ...@@ -41,7 +41,7 @@ class BoundingBox(_Feature):
if isinstance(format, str): if isinstance(format, str):
format = BoundingBoxFormat.from_str(format.upper()) format = BoundingBoxFormat.from_str(format.upper())
return cls._wrap(tensor, format=format, image_size=image_size) return cls._wrap(tensor, format=format, spatial_size=spatial_size)
@classmethod @classmethod
def wrap_like( def wrap_like(
...@@ -50,16 +50,16 @@ class BoundingBox(_Feature): ...@@ -50,16 +50,16 @@ class BoundingBox(_Feature):
tensor: torch.Tensor, tensor: torch.Tensor,
*, *,
format: Optional[BoundingBoxFormat] = None, format: Optional[BoundingBoxFormat] = None,
image_size: Optional[Tuple[int, int]] = None, spatial_size: Optional[Tuple[int, int]] = None,
) -> BoundingBox: ) -> BoundingBox:
return cls._wrap( return cls._wrap(
tensor, tensor,
format=format if format is not None else other.format, format=format if format is not None else other.format,
image_size=image_size if image_size is not None else other.image_size, spatial_size=spatial_size if spatial_size is not None else other.spatial_size,
) )
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(format=self.format, image_size=self.image_size) return self._make_repr(format=self.format, spatial_size=self.spatial_size)
def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
if isinstance(format, str): if isinstance(format, str):
...@@ -70,11 +70,11 @@ class BoundingBox(_Feature): ...@@ -70,11 +70,11 @@ class BoundingBox(_Feature):
) )
def horizontal_flip(self) -> BoundingBox: def horizontal_flip(self) -> BoundingBox:
output = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size) output = self._F.horizontal_flip_bounding_box(self, format=self.format, spatial_size=self.spatial_size)
return BoundingBox.wrap_like(self, output) return BoundingBox.wrap_like(self, output)
def vertical_flip(self) -> BoundingBox: def vertical_flip(self) -> BoundingBox:
output = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size) output = self._F.vertical_flip_bounding_box(self, format=self.format, spatial_size=self.spatial_size)
return BoundingBox.wrap_like(self, output) return BoundingBox.wrap_like(self, output)
def resize( # type: ignore[override] def resize( # type: ignore[override]
...@@ -84,20 +84,22 @@ class BoundingBox(_Feature): ...@@ -84,20 +84,22 @@ class BoundingBox(_Feature):
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: bool = False, antialias: bool = False,
) -> BoundingBox: ) -> BoundingBox:
output, image_size = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size) output, spatial_size = self._F.resize_bounding_box(
return BoundingBox.wrap_like(self, output, image_size=image_size) self, spatial_size=self.spatial_size, size=size, max_size=max_size
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
output, image_size = self._F.crop_bounding_box( output, spatial_size = self._F.crop_bounding_box(
self, self.format, top=top, left=left, height=height, width=width self, self.format, top=top, left=left, height=height, width=width
) )
return BoundingBox.wrap_like(self, output, image_size=image_size) return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
def center_crop(self, output_size: List[int]) -> BoundingBox: def center_crop(self, output_size: List[int]) -> BoundingBox:
output, image_size = self._F.center_crop_bounding_box( output, spatial_size = self._F.center_crop_bounding_box(
self, format=self.format, image_size=self.image_size, output_size=output_size self, format=self.format, spatial_size=self.spatial_size, output_size=output_size
) )
return BoundingBox.wrap_like(self, output, image_size=image_size) return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
def resized_crop( def resized_crop(
self, self,
...@@ -109,8 +111,8 @@ class BoundingBox(_Feature): ...@@ -109,8 +111,8 @@ class BoundingBox(_Feature):
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False, antialias: bool = False,
) -> BoundingBox: ) -> BoundingBox:
output, image_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size) output, spatial_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
return BoundingBox.wrap_like(self, output, image_size=image_size) return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
def pad( def pad(
self, self,
...@@ -118,10 +120,10 @@ class BoundingBox(_Feature): ...@@ -118,10 +120,10 @@ class BoundingBox(_Feature):
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> BoundingBox: ) -> BoundingBox:
output, image_size = self._F.pad_bounding_box( output, spatial_size = self._F.pad_bounding_box(
self, format=self.format, image_size=self.image_size, padding=padding, padding_mode=padding_mode self, format=self.format, spatial_size=self.spatial_size, padding=padding, padding_mode=padding_mode
) )
return BoundingBox.wrap_like(self, output, image_size=image_size) return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
def rotate( def rotate(
self, self,
...@@ -131,10 +133,10 @@ class BoundingBox(_Feature): ...@@ -131,10 +133,10 @@ class BoundingBox(_Feature):
fill: FillTypeJIT = None, fill: FillTypeJIT = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> BoundingBox: ) -> BoundingBox:
output, image_size = self._F.rotate_bounding_box( output, spatial_size = self._F.rotate_bounding_box(
self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center self, format=self.format, spatial_size=self.spatial_size, angle=angle, expand=expand, center=center
) )
return BoundingBox.wrap_like(self, output, image_size=image_size) return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
def affine( def affine(
self, self,
...@@ -149,7 +151,7 @@ class BoundingBox(_Feature): ...@@ -149,7 +151,7 @@ class BoundingBox(_Feature):
output = self._F.affine_bounding_box( output = self._F.affine_bounding_box(
self, self,
self.format, self.format,
self.image_size, self.spatial_size,
angle, angle,
translate=translate, translate=translate,
scale=scale, scale=scale,
......
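For reference, a minimal sketch of how the renamed metadata flows through the prototype `BoundingBox` API after this change, assuming the prototype module from this branch is importable; the box coordinates and canvas size below are illustrative only:

```python
import torch
from torchvision.prototype import features

# One box in XYXY format on a (height=64, width=48) canvas.
bbox = features.BoundingBox(
    [10.0, 10.0, 30.0, 40.0],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=(64, 48),
    dtype=torch.float32,
)

# Flipping keeps the canvas, so spatial_size is unchanged.
flipped = bbox.horizontal_flip()
print(flipped.spatial_size)  # (64, 48)

# Cropping changes the canvas, so the kernel returns a new spatial_size
# and wrap_like attaches it to the output feature.
cropped = bbox.crop(top=5, left=5, height=32, width=32)
print(cropped.spatial_size)  # (32, 32)
```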
...@@ -49,12 +49,12 @@ class EncodedData(_Feature): ...@@ -49,12 +49,12 @@ class EncodedData(_Feature):
class EncodedImage(EncodedData): class EncodedImage(EncodedData):
# TODO: Use @functools.cached_property if we can depend on Python 3.8 # TODO: Use @functools.cached_property if we can depend on Python 3.8
@property @property
def image_size(self) -> Tuple[int, int]: def spatial_size(self) -> Tuple[int, int]:
if not hasattr(self, "_image_size"): if not hasattr(self, "_spatial_size"):
with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image: with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image:
self._image_size = image.height, image.width self._spatial_size = image.height, image.width
return self._image_size return self._spatial_size
class EncodedVideo(EncodedData): class EncodedVideo(EncodedData):
......
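The lazily cached property above avoids decoding the full image just to learn its dimensions: PIL only parses the header when the size is first requested, and the result is memoized on the instance. A self-contained sketch of the same pattern, independent of the torchvision tensor types:

```python
import io
from typing import Tuple

import PIL.Image


class LazySpatialSize:
    """Holds encoded image bytes and reads (height, width) only on demand."""

    def __init__(self, data: bytes) -> None:
        self._data = data

    @property
    def spatial_size(self) -> Tuple[int, int]:
        # Parse the image header at most once and cache the result.
        if not hasattr(self, "_spatial_size"):
            with PIL.Image.open(io.BytesIO(self._data)) as image:
                self._spatial_size = image.height, image.width
        return self._spatial_size
```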
...@@ -105,7 +105,7 @@ class Image(_Feature): ...@@ -105,7 +105,7 @@ class Image(_Feature):
return self._make_repr(color_space=self.color_space) return self._make_repr(color_space=self.color_space)
@property @property
def image_size(self) -> Tuple[int, int]: def spatial_size(self) -> Tuple[int, int]:
return cast(Tuple[int, int], tuple(self.shape[-2:])) return cast(Tuple[int, int], tuple(self.shape[-2:]))
@property @property
......
...@@ -33,7 +33,7 @@ class Mask(_Feature): ...@@ -33,7 +33,7 @@ class Mask(_Feature):
return cls._wrap(tensor) return cls._wrap(tensor)
@property @property
def image_size(self) -> Tuple[int, int]: def spatial_size(self) -> Tuple[int, int]:
return cast(Tuple[int, int], tuple(self.shape[-2:])) return cast(Tuple[int, int], tuple(self.shape[-2:]))
def horizontal_flip(self) -> Mask: def horizontal_flip(self) -> Mask:
......
...@@ -54,9 +54,8 @@ class Video(_Feature): ...@@ -54,9 +54,8 @@ class Video(_Feature):
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(color_space=self.color_space) return self._make_repr(color_space=self.color_space)
# TODO: rename this (and all instances of this term to spatial size)
@property @property
def image_size(self) -> Tuple[int, int]: def spatial_size(self) -> Tuple[int, int]:
return cast(Tuple[int, int], tuple(self.shape[-2:])) return cast(Tuple[int, int], tuple(self.shape[-2:]))
@property @property
......
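All of the tensor-backed features follow the same convention: `spatial_size` is simply the trailing two dimensions of the underlying shape, i.e. `(height, width)`, regardless of any leading channel or frame dimensions. A quick illustration with plain tensors standing in for `Image` and `Video` data:

```python
import torch

image_data = torch.rand(3, 720, 1280)      # (C, H, W)
video_data = torch.rand(16, 3, 240, 320)   # (T, C, H, W)

print(tuple(image_data.shape[-2:]))  # (720, 1280)
print(tuple(video_data.shape[-2:]))  # (240, 320)
```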
...@@ -11,7 +11,7 @@ from torchvision.prototype import features ...@@ -11,7 +11,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, InterpolationMode from torchvision.prototype.transforms import functional as F, InterpolationMode
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
from ._utils import has_any, query_chw from ._utils import has_any, query_chw, query_spatial_size
class RandomErasing(_RandomApplyTransform): class RandomErasing(_RandomApplyTransform):
...@@ -153,7 +153,7 @@ class RandomCutmix(_BaseMixupCutmix): ...@@ -153,7 +153,7 @@ class RandomCutmix(_BaseMixupCutmix):
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
lam = float(self._dist.sample(())) lam = float(self._dist.sample(()))
_, H, W = query_chw(sample) H, W = query_spatial_size(sample)
r_x = torch.randint(W, ()) r_x = torch.randint(W, ())
r_y = torch.randint(H, ()) r_y = torch.randint(H, ())
......
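`query_spatial_size` is expected to return the `(height, width)` of whatever image-like feature is present in the sample, so transforms that only need the canvas size no longer have to unpack the channel count from `query_chw`. A hypothetical sketch of such a helper, written here only for illustration and not taken from the actual torchvision implementation:

```python
from typing import Any, Tuple

import torch


def query_spatial_size_sketch(sample: Any) -> Tuple[int, int]:
    # Illustration only: the real helper dispatches on the prototype feature
    # types (Image, Video, Mask, BoundingBox, ...) rather than raw tensors.
    flat = sample if isinstance(sample, (list, tuple)) else [sample]
    for item in flat:
        if isinstance(item, torch.Tensor) and item.ndim >= 2:
            height, width = item.shape[-2:]
            return int(height), int(width)
    raise TypeError("No image-like tensor with spatial dimensions found in the sample")
```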
...@@ -100,7 +100,7 @@ class RandomPhotometricDistort(Transform): ...@@ -100,7 +100,7 @@ class RandomPhotometricDistort(Transform):
self.p = p self.p = p
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
num_channels, _, _ = query_chw(sample) num_channels, *_ = query_chw(sample)
return dict( return dict(
zip( zip(
["brightness", "contrast1", "saturation", "hue", "contrast2"], ["brightness", "contrast1", "saturation", "hue", "contrast2"],
......