Unverified commit 5dd95944, authored by Nicolas Hug, committed by GitHub

Remove `color_space` metadata and `ConvertColorSpace()` transform (#7120)

parent c206a471
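The user-facing effect of this change, sketched from the hunks that follow (a minimal sketch, assuming the prototype import path `torchvision.prototype.datapoints` that the diff itself uses; the prototype API is not stable): `Image` and `Video` no longer accept or carry a `color_space` argument, and the test helpers now pass plain strings such as `"RGB"` where they previously passed `datapoints.ColorSpace` members.

```python
import torch
from torchvision.prototype import datapoints

data = torch.rand(3, 32, 32)

# Before this commit, the color space was attached to the datapoint as metadata:
#   image = datapoints.Image(data, color_space=datapoints.ColorSpace.RGB)

# After this commit, the datapoint only wraps the tensor; the channel layout
# is implied by the shape (3 channels -> RGB by convention).
image = datapoints.Image(data)
print(image.shape)  # torch.Size([3, 32, 32])
```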
...@@ -238,7 +238,6 @@ class TensorLoader: ...@@ -238,7 +238,6 @@ class TensorLoader:
@dataclasses.dataclass @dataclasses.dataclass
class ImageLoader(TensorLoader): class ImageLoader(TensorLoader):
color_space: datapoints.ColorSpace
spatial_size: Tuple[int, int] = dataclasses.field(init=False) spatial_size: Tuple[int, int] = dataclasses.field(init=False)
num_channels: int = dataclasses.field(init=False) num_channels: int = dataclasses.field(init=False)
...@@ -248,10 +247,10 @@ class ImageLoader(TensorLoader): ...@@ -248,10 +247,10 @@ class ImageLoader(TensorLoader):
NUM_CHANNELS_MAP = { NUM_CHANNELS_MAP = {
datapoints.ColorSpace.GRAY: 1, "GRAY": 1,
datapoints.ColorSpace.GRAY_ALPHA: 2, "GRAY_ALPHA": 2,
datapoints.ColorSpace.RGB: 3, "RGB": 3,
datapoints.ColorSpace.RGB_ALPHA: 4, "RGBA": 4,
} }
...@@ -265,7 +264,7 @@ def get_num_channels(color_space): ...@@ -265,7 +264,7 @@ def get_num_channels(color_space):
def make_image_loader( def make_image_loader(
size="random", size="random",
*, *,
color_space=datapoints.ColorSpace.RGB, color_space="RGB",
extra_dims=(), extra_dims=(),
dtype=torch.float32, dtype=torch.float32,
constant_alpha=True, constant_alpha=True,
...@@ -276,11 +275,11 @@ def make_image_loader( ...@@ -276,11 +275,11 @@ def make_image_loader(
def fn(shape, dtype, device): def fn(shape, dtype, device):
max_value = get_max_value(dtype) max_value = get_max_value(dtype)
data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device) data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha: if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
data[..., -1, :, :] = max_value data[..., -1, :, :] = max_value
return datapoints.Image(data, color_space=color_space) return datapoints.Image(data)
return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space) return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)
make_image = from_loader(make_image_loader) make_image = from_loader(make_image_loader)
...@@ -290,10 +289,10 @@ def make_image_loaders( ...@@ -290,10 +289,10 @@ def make_image_loaders(
*, *,
sizes=DEFAULT_SPATIAL_SIZES, sizes=DEFAULT_SPATIAL_SIZES,
color_spaces=( color_spaces=(
datapoints.ColorSpace.GRAY, "GRAY",
datapoints.ColorSpace.GRAY_ALPHA, "GRAY_ALPHA",
datapoints.ColorSpace.RGB, "RGB",
datapoints.ColorSpace.RGB_ALPHA, "RGBA",
), ),
extra_dims=DEFAULT_EXTRA_DIMS, extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.float32, torch.uint8), dtypes=(torch.float32, torch.uint8),
...@@ -306,7 +305,7 @@ def make_image_loaders( ...@@ -306,7 +305,7 @@ def make_image_loaders(
make_images = from_loaders(make_image_loaders) make_images = from_loaders(make_image_loaders)
def make_image_loader_for_interpolation(size="random", *, color_space=datapoints.ColorSpace.RGB, dtype=torch.uint8): def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
size = _parse_spatial_size(size) size = _parse_spatial_size(size)
num_channels = get_num_channels(color_space) num_channels = get_num_channels(color_space)
...@@ -318,24 +317,24 @@ def make_image_loader_for_interpolation(size="random", *, color_space=datapoints ...@@ -318,24 +317,24 @@ def make_image_loader_for_interpolation(size="random", *, color_space=datapoints
.resize((width, height)) .resize((width, height))
.convert( .convert(
{ {
datapoints.ColorSpace.GRAY: "L", "GRAY": "L",
datapoints.ColorSpace.GRAY_ALPHA: "LA", "GRAY_ALPHA": "LA",
datapoints.ColorSpace.RGB: "RGB", "RGB": "RGB",
datapoints.ColorSpace.RGB_ALPHA: "RGBA", "RGBA": "RGBA",
}[color_space] }[color_space]
) )
) )
image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype) image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
return datapoints.Image(image_tensor, color_space=color_space) return datapoints.Image(image_tensor)
return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space) return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
def make_image_loaders_for_interpolation( def make_image_loaders_for_interpolation(
sizes=((233, 147),), sizes=((233, 147),),
color_spaces=(datapoints.ColorSpace.RGB,), color_spaces=("RGB",),
dtypes=(torch.uint8,), dtypes=(torch.uint8,),
): ):
for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes): for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
...@@ -583,7 +582,7 @@ class VideoLoader(ImageLoader): ...@@ -583,7 +582,7 @@ class VideoLoader(ImageLoader):
def make_video_loader( def make_video_loader(
size="random", size="random",
*, *,
color_space=datapoints.ColorSpace.RGB, color_space="RGB",
num_frames="random", num_frames="random",
extra_dims=(), extra_dims=(),
dtype=torch.uint8, dtype=torch.uint8,
...@@ -592,12 +591,10 @@ def make_video_loader( ...@@ -592,12 +591,10 @@ def make_video_loader(
num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
def fn(shape, dtype, device): def fn(shape, dtype, device):
video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device) video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
return datapoints.Video(video, color_space=color_space) return datapoints.Video(video)
-    return VideoLoader(
-        fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space
-    )
+    return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
make_video = from_loader(make_video_loader) make_video = from_loader(make_video_loader)
...@@ -607,8 +604,8 @@ def make_video_loaders( ...@@ -607,8 +604,8 @@ def make_video_loaders(
*, *,
sizes=DEFAULT_SPATIAL_SIZES, sizes=DEFAULT_SPATIAL_SIZES,
color_spaces=( color_spaces=(
datapoints.ColorSpace.GRAY, "GRAY",
datapoints.ColorSpace.RGB, "RGB",
), ),
num_frames=(1, 0, "random"), num_frames=(1, 0, "random"),
extra_dims=DEFAULT_EXTRA_DIMS, extra_dims=DEFAULT_EXTRA_DIMS,
......
...@@ -9,7 +9,6 @@ import pytest ...@@ -9,7 +9,6 @@ import pytest
import torch.testing import torch.testing
import torchvision.ops import torchvision.ops
import torchvision.prototype.transforms.functional as F import torchvision.prototype.transforms.functional as F
from common_utils import cycle_over
from datasets_utils import combinations_grid from datasets_utils import combinations_grid
from prototype_common_utils import ( from prototype_common_utils import (
ArgsKwargs, ArgsKwargs,
...@@ -261,14 +260,12 @@ def _get_resize_sizes(spatial_size): ...@@ -261,14 +260,12 @@ def _get_resize_sizes(spatial_size):
def sample_inputs_resize_image_tensor():
-    for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
-    ):
+    for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]):
        for size in _get_resize_sizes(image_loader.spatial_size):
            yield ArgsKwargs(image_loader, size=size)
for image_loader, interpolation in itertools.product( for image_loader, interpolation in itertools.product(
make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB]), make_image_loaders(sizes=["random"], color_spaces=["RGB"]),
[ [
F.InterpolationMode.NEAREST, F.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR, F.InterpolationMode.BILINEAR,
...@@ -472,7 +469,7 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs): ...@@ -472,7 +469,7 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs):
def sample_inputs_affine_image_tensor(): def sample_inputs_affine_image_tensor():
make_affine_image_loaders = functools.partial( make_affine_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
) )
for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS): for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS):
...@@ -684,69 +681,6 @@ KERNEL_INFOS.append( ...@@ -684,69 +681,6 @@ KERNEL_INFOS.append(
) )
def sample_inputs_convert_color_space_image_tensor():
color_spaces = sorted(
set(datapoints.ColorSpace) - {datapoints.ColorSpace.OTHER}, key=lambda color_space: color_space.value
)
for old_color_space, new_color_space in cycle_over(color_spaces):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=[old_color_space], constant_alpha=True):
yield ArgsKwargs(image_loader, old_color_space=old_color_space, new_color_space=new_color_space)
for color_space in color_spaces:
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=[color_space], dtypes=[torch.float32], constant_alpha=True
):
yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space)
@pil_reference_wrapper
def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space):
color_space_pil = datapoints.ColorSpace.from_pil_mode(image_pil.mode)
if color_space_pil != old_color_space:
raise pytest.UsageError(
f"Converting the tensor image into an PIL image changed the colorspace "
f"from {old_color_space} to {color_space_pil}"
)
return F.convert_color_space_image_pil(image_pil, color_space=new_color_space)
def reference_inputs_convert_color_space_image_tensor():
for args_kwargs in sample_inputs_convert_color_space_image_tensor():
(image_loader, *other_args), kwargs = args_kwargs
if len(image_loader.shape) == 3 and image_loader.dtype == torch.uint8:
yield args_kwargs
def sample_inputs_convert_color_space_video():
color_spaces = [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB]
for old_color_space, new_color_space in cycle_over(color_spaces):
for video_loader in make_video_loaders(sizes=["random"], color_spaces=[old_color_space], num_frames=["random"]):
yield ArgsKwargs(video_loader, old_color_space=old_color_space, new_color_space=new_color_space)
KERNEL_INFOS.extend(
[
KernelInfo(
F.convert_color_space_image_tensor,
sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
reference_fn=reference_convert_color_space_image_tensor,
reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
closeness_kwargs={
**pil_reference_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
},
),
KernelInfo(
F.convert_color_space_video,
sample_inputs_fn=sample_inputs_convert_color_space_video,
),
]
)
def sample_inputs_vertical_flip_image_tensor(): def sample_inputs_vertical_flip_image_tensor():
for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]): for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]):
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
...@@ -822,7 +756,7 @@ _ROTATE_ANGLES = [-87, 15, 90] ...@@ -822,7 +756,7 @@ _ROTATE_ANGLES = [-87, 15, 90]
def sample_inputs_rotate_image_tensor(): def sample_inputs_rotate_image_tensor():
make_rotate_image_loaders = functools.partial( make_rotate_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
) )
for image_loader in make_rotate_image_loaders(): for image_loader in make_rotate_image_loaders():
...@@ -904,7 +838,7 @@ _CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20 ...@@ -904,7 +838,7 @@ _CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20
def sample_inputs_crop_image_tensor(): def sample_inputs_crop_image_tensor():
for image_loader, params in itertools.product( for image_loader, params in itertools.product(
make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
[ [
dict(top=4, left=3, height=7, width=8), dict(top=4, left=3, height=7, width=8),
dict(top=-1, left=3, height=7, width=8), dict(top=-1, left=3, height=7, width=8),
...@@ -1090,7 +1024,7 @@ _PAD_PARAMS = combinations_grid( ...@@ -1090,7 +1024,7 @@ _PAD_PARAMS = combinations_grid(
def sample_inputs_pad_image_tensor(): def sample_inputs_pad_image_tensor():
make_pad_image_loaders = functools.partial( make_pad_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32] make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
) )
for image_loader, padding in itertools.product( for image_loader, padding in itertools.product(
...@@ -1406,7 +1340,7 @@ _CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)] ...@@ -1406,7 +1340,7 @@ _CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]
def sample_inputs_center_crop_image_tensor(): def sample_inputs_center_crop_image_tensor():
for image_loader, output_size in itertools.product( for image_loader, output_size in itertools.product(
make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
[ [
# valid `output_size` types for which cropping is applied to both dimensions # valid `output_size` types for which cropping is applied to both dimensions
*[5, (4,), (2, 3), [6], [3, 2]], *[5, (4,), (2, 3), [6], [3, 2]],
...@@ -1492,9 +1426,7 @@ KERNEL_INFOS.extend( ...@@ -1492,9 +1426,7 @@ KERNEL_INFOS.extend(
def sample_inputs_gaussian_blur_image_tensor():
-    make_gaussian_blur_image_loaders = functools.partial(
-        make_image_loaders, sizes=[(7, 33)], color_spaces=[datapoints.ColorSpace.RGB]
-    )
+    make_gaussian_blur_image_loaders = functools.partial(make_image_loaders, sizes=[(7, 33)], color_spaces=["RGB"])
    for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
        yield ArgsKwargs(image_loader, kernel_size=kernel_size)
...@@ -1531,9 +1463,7 @@ KERNEL_INFOS.extend( ...@@ -1531,9 +1463,7 @@ KERNEL_INFOS.extend(
def sample_inputs_equalize_image_tensor(): def sample_inputs_equalize_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
...@@ -1560,7 +1490,7 @@ def reference_inputs_equalize_image_tensor(): ...@@ -1560,7 +1490,7 @@ def reference_inputs_equalize_image_tensor():
spatial_size = (256, 256) spatial_size = (256, 256)
for dtype, color_space, fn in itertools.product( for dtype, color_space, fn in itertools.product(
[torch.uint8], [torch.uint8],
[datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB], ["GRAY", "RGB"],
[ [
lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device), lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
lambda shape, dtype, device: torch.full( lambda shape, dtype, device: torch.full(
...@@ -1585,9 +1515,7 @@ def reference_inputs_equalize_image_tensor(): ...@@ -1585,9 +1515,7 @@ def reference_inputs_equalize_image_tensor():
], ],
], ],
    ):
-        image_loader = ImageLoader(
-            fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype, color_space=color_space
-        )
+        image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype)
        yield ArgsKwargs(image_loader)
...@@ -1615,16 +1543,12 @@ KERNEL_INFOS.extend( ...@@ -1615,16 +1543,12 @@ KERNEL_INFOS.extend(
def sample_inputs_invert_image_tensor(): def sample_inputs_invert_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
def reference_inputs_invert_image_tensor(): def reference_inputs_invert_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
):
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
...@@ -1655,17 +1579,13 @@ _POSTERIZE_BITS = [1, 4, 8] ...@@ -1655,17 +1579,13 @@ _POSTERIZE_BITS = [1, 4, 8]
def sample_inputs_posterize_image_tensor(): def sample_inputs_posterize_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0]) yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])
def reference_inputs_posterize_image_tensor(): def reference_inputs_posterize_image_tensor():
for image_loader, bits in itertools.product( for image_loader, bits in itertools.product(
make_image_loaders( make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_POSTERIZE_BITS, _POSTERIZE_BITS,
): ):
yield ArgsKwargs(image_loader, bits=bits) yield ArgsKwargs(image_loader, bits=bits)
...@@ -1702,16 +1622,12 @@ def _get_solarize_thresholds(dtype): ...@@ -1702,16 +1622,12 @@ def _get_solarize_thresholds(dtype):
def sample_inputs_solarize_image_tensor(): def sample_inputs_solarize_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype))) yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))
def reference_inputs_solarize_image_tensor(): def reference_inputs_solarize_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
):
for threshold in _get_solarize_thresholds(image_loader.dtype): for threshold in _get_solarize_thresholds(image_loader.dtype):
yield ArgsKwargs(image_loader, threshold=threshold) yield ArgsKwargs(image_loader, threshold=threshold)
...@@ -1745,16 +1661,12 @@ KERNEL_INFOS.extend( ...@@ -1745,16 +1661,12 @@ KERNEL_INFOS.extend(
def sample_inputs_autocontrast_image_tensor(): def sample_inputs_autocontrast_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
def reference_inputs_autocontrast_image_tensor(): def reference_inputs_autocontrast_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
):
yield ArgsKwargs(image_loader) yield ArgsKwargs(image_loader)
...@@ -1790,16 +1702,14 @@ _ADJUST_SHARPNESS_FACTORS = [0.1, 0.5] ...@@ -1790,16 +1702,14 @@ _ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_sharpness_image_tensor(): def sample_inputs_adjust_sharpness_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=["random", (2, 2)], sizes=["random", (2, 2)],
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), color_spaces=("GRAY", "RGB"),
): ):
yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0]) yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])
def reference_inputs_adjust_sharpness_image_tensor(): def reference_inputs_adjust_sharpness_image_tensor():
for image_loader, sharpness_factor in itertools.product( for image_loader, sharpness_factor in itertools.product(
make_image_loaders( make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_SHARPNESS_FACTORS, _ADJUST_SHARPNESS_FACTORS,
): ):
yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor) yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)
...@@ -1863,17 +1773,13 @@ _ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5] ...@@ -1863,17 +1773,13 @@ _ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_brightness_image_tensor(): def sample_inputs_adjust_brightness_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0]) yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])
def reference_inputs_adjust_brightness_image_tensor(): def reference_inputs_adjust_brightness_image_tensor():
for image_loader, brightness_factor in itertools.product( for image_loader, brightness_factor in itertools.product(
make_image_loaders( make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_BRIGHTNESS_FACTORS, _ADJUST_BRIGHTNESS_FACTORS,
): ):
yield ArgsKwargs(image_loader, brightness_factor=brightness_factor) yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)
...@@ -1907,17 +1813,13 @@ _ADJUST_CONTRAST_FACTORS = [0.1, 0.5] ...@@ -1907,17 +1813,13 @@ _ADJUST_CONTRAST_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_contrast_image_tensor(): def sample_inputs_adjust_contrast_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0]) yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])
def reference_inputs_adjust_contrast_image_tensor(): def reference_inputs_adjust_contrast_image_tensor():
for image_loader, contrast_factor in itertools.product( for image_loader, contrast_factor in itertools.product(
make_image_loaders( make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_CONTRAST_FACTORS, _ADJUST_CONTRAST_FACTORS,
): ):
yield ArgsKwargs(image_loader, contrast_factor=contrast_factor) yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)
...@@ -1959,17 +1861,13 @@ _ADJUST_GAMMA_GAMMAS_GAINS = [ ...@@ -1959,17 +1861,13 @@ _ADJUST_GAMMA_GAMMAS_GAINS = [
def sample_inputs_adjust_gamma_image_tensor(): def sample_inputs_adjust_gamma_image_tensor():
gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0] gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
for image_loader in make_image_loaders( for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
def reference_inputs_adjust_gamma_image_tensor(): def reference_inputs_adjust_gamma_image_tensor():
for image_loader, (gamma, gain) in itertools.product( for image_loader, (gamma, gain) in itertools.product(
make_image_loaders( make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_GAMMA_GAMMAS_GAINS, _ADJUST_GAMMA_GAMMAS_GAINS,
): ):
yield ArgsKwargs(image_loader, gamma=gamma, gain=gain) yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
...@@ -2007,17 +1905,13 @@ _ADJUST_HUE_FACTORS = [-0.1, 0.5] ...@@ -2007,17 +1905,13 @@ _ADJUST_HUE_FACTORS = [-0.1, 0.5]
def sample_inputs_adjust_hue_image_tensor(): def sample_inputs_adjust_hue_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0]) yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])
def reference_inputs_adjust_hue_image_tensor(): def reference_inputs_adjust_hue_image_tensor():
for image_loader, hue_factor in itertools.product( for image_loader, hue_factor in itertools.product(
make_image_loaders( make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_HUE_FACTORS, _ADJUST_HUE_FACTORS,
): ):
yield ArgsKwargs(image_loader, hue_factor=hue_factor) yield ArgsKwargs(image_loader, hue_factor=hue_factor)
...@@ -2053,17 +1947,13 @@ _ADJUST_SATURATION_FACTORS = [0.1, 0.5] ...@@ -2053,17 +1947,13 @@ _ADJUST_SATURATION_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_saturation_image_tensor(): def sample_inputs_adjust_saturation_image_tensor():
for image_loader in make_image_loaders( for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0]) yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])
def reference_inputs_adjust_saturation_image_tensor(): def reference_inputs_adjust_saturation_image_tensor():
for image_loader, saturation_factor in itertools.product( for image_loader, saturation_factor in itertools.product(
make_image_loaders( make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
_ADJUST_SATURATION_FACTORS, _ADJUST_SATURATION_FACTORS,
): ):
yield ArgsKwargs(image_loader, saturation_factor=saturation_factor) yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)
...@@ -2128,7 +2018,7 @@ def sample_inputs_five_crop_image_tensor(): ...@@ -2128,7 +2018,7 @@ def sample_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES: for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], sizes=[_get_five_ten_crop_spatial_size(size)],
color_spaces=[datapoints.ColorSpace.RGB], color_spaces=["RGB"],
dtypes=[torch.float32], dtypes=[torch.float32],
): ):
yield ArgsKwargs(image_loader, size=size) yield ArgsKwargs(image_loader, size=size)
...@@ -2152,7 +2042,7 @@ def sample_inputs_ten_crop_image_tensor(): ...@@ -2152,7 +2042,7 @@ def sample_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]): for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders( for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)], sizes=[_get_five_ten_crop_spatial_size(size)],
color_spaces=[datapoints.ColorSpace.RGB], color_spaces=["RGB"],
dtypes=[torch.float32], dtypes=[torch.float32],
): ):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip) yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
...@@ -2226,7 +2116,7 @@ _NORMALIZE_MEANS_STDS = [ ...@@ -2226,7 +2116,7 @@ _NORMALIZE_MEANS_STDS = [
def sample_inputs_normalize_image_tensor(): def sample_inputs_normalize_image_tensor():
for image_loader, (mean, std) in itertools.product( for image_loader, (mean, std) in itertools.product(
make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]), make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]),
_NORMALIZE_MEANS_STDS, _NORMALIZE_MEANS_STDS,
): ):
yield ArgsKwargs(image_loader, mean=mean, std=std) yield ArgsKwargs(image_loader, mean=mean, std=std)
...@@ -2242,7 +2132,7 @@ def reference_normalize_image_tensor(image, mean, std, inplace=False): ...@@ -2242,7 +2132,7 @@ def reference_normalize_image_tensor(image, mean, std, inplace=False):
def reference_inputs_normalize_image_tensor(): def reference_inputs_normalize_image_tensor():
yield ArgsKwargs( yield ArgsKwargs(
make_image_loader(size=(32, 32), color_space=datapoints.ColorSpace.RGB, extra_dims=[1]), make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]),
mean=[0.5, 0.5, 0.5], mean=[0.5, 0.5, 0.5],
std=[1.0, 1.0, 1.0], std=[1.0, 1.0, 1.0],
) )
...@@ -2251,7 +2141,7 @@ def reference_inputs_normalize_image_tensor(): ...@@ -2251,7 +2141,7 @@ def reference_inputs_normalize_image_tensor():
def sample_inputs_normalize_video(): def sample_inputs_normalize_video():
mean, std = _NORMALIZE_MEANS_STDS[0] mean, std = _NORMALIZE_MEANS_STDS[0]
for video_loader in make_video_loaders( for video_loader in make_video_loaders(
sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32] sizes=["random"], color_spaces=["RGB"], num_frames=["random"], dtypes=[torch.float32]
): ):
yield ArgsKwargs(video_loader, mean=mean, std=std) yield ArgsKwargs(video_loader, mean=mean, std=std)
...@@ -2285,9 +2175,7 @@ def sample_inputs_convert_dtype_image_tensor(): ...@@ -2285,9 +2175,7 @@ def sample_inputs_convert_dtype_image_tensor():
# conversion cannot be performed safely # conversion cannot be performed safely
continue continue
for image_loader in make_image_loaders( for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[input_dtype]):
sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[input_dtype]
):
yield ArgsKwargs(image_loader, dtype=output_dtype) yield ArgsKwargs(image_loader, dtype=output_dtype)
...@@ -2414,7 +2302,7 @@ def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4): ...@@ -2414,7 +2302,7 @@ def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4):
def reference_inputs_uniform_temporal_subsample_video(): def reference_inputs_uniform_temporal_subsample_video():
for video_loader in make_video_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=[10]): for video_loader in make_video_loaders(sizes=["random"], color_spaces=["RGB"], num_frames=[10]):
for num_samples in range(1, video_loader.shape[-4] + 1): for num_samples in range(1, video_loader.shape[-4] + 1):
yield ArgsKwargs(video_loader, num_samples) yield ArgsKwargs(video_loader, num_samples)
......
...@@ -161,8 +161,8 @@ class TestSmoke: ...@@ -161,8 +161,8 @@ class TestSmoke:
itertools.chain.from_iterable( itertools.chain.from_iterable(
fn( fn(
color_spaces=[ color_spaces=[
datapoints.ColorSpace.GRAY, "GRAY",
datapoints.ColorSpace.RGB, "RGB",
], ],
dtypes=[torch.uint8], dtypes=[torch.uint8],
extra_dims=[(), (4,)], extra_dims=[(), (4,)],
...@@ -192,7 +192,7 @@ class TestSmoke: ...@@ -192,7 +192,7 @@ class TestSmoke:
( (
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]), transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
itertools.chain.from_iterable( itertools.chain.from_iterable(
fn(color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]) fn(color_spaces=["RGB"], dtypes=[torch.float32])
for fn in [ for fn in [
make_images, make_images,
make_vanilla_tensor_images, make_vanilla_tensor_images,
...@@ -221,45 +221,6 @@ class TestSmoke: ...@@ -221,45 +221,6 @@ class TestSmoke:
def test_random_resized_crop(self, transform, input): def test_random_resized_crop(self, transform, input):
transform(input) transform(input)
@parametrize(
[
(
transforms.ConvertColorSpace(color_space=new_color_space, old_color_space=old_color_space),
itertools.chain.from_iterable(
[
fn(color_spaces=[old_color_space])
for fn in (
make_images,
make_vanilla_tensor_images,
make_pil_images,
make_videos,
)
]
),
)
for old_color_space, new_color_space in itertools.product(
[
datapoints.ColorSpace.GRAY,
datapoints.ColorSpace.GRAY_ALPHA,
datapoints.ColorSpace.RGB,
datapoints.ColorSpace.RGB_ALPHA,
],
repeat=2,
)
]
)
def test_convert_color_space(self, transform, input):
transform(input)
def test_convert_color_space_unsupported_types(self):
transform = transforms.ConvertColorSpace(
color_space=datapoints.ColorSpace.RGB, old_color_space=datapoints.ColorSpace.GRAY
)
for inpt in [make_bounding_box(format="XYXY"), make_masks()]:
output = transform(inpt)
assert output is inpt
@pytest.mark.parametrize("p", [0.0, 1.0]) @pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomHorizontalFlip: class TestRandomHorizontalFlip:
...@@ -1558,7 +1519,7 @@ class TestFixedSizeCrop: ...@@ -1558,7 +1519,7 @@ class TestFixedSizeCrop:
transform = transforms.FixedSizeCrop(size=crop_size) transform = transforms.FixedSizeCrop(size=crop_size)
flat_inputs = [ flat_inputs = [
make_image(size=spatial_size, color_space=datapoints.ColorSpace.RGB), make_image(size=spatial_size, color_space="RGB"),
make_bounding_box( make_bounding_box(
format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape
), ),
......
...@@ -31,7 +31,7 @@ from torchvision.prototype.transforms.functional import to_image_pil ...@@ -31,7 +31,7 @@ from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import query_spatial_size from torchvision.prototype.transforms.utils import query_spatial_size
from torchvision.transforms import functional as legacy_F from torchvision.transforms import functional as legacy_F
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[datapoints.ColorSpace.RGB], extra_dims=[(4,)]) DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
class ConsistencyConfig: class ConsistencyConfig:
...@@ -138,9 +138,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -138,9 +138,7 @@ CONSISTENCY_CONFIGS = [
], ],
# Make sure that the product of the height, width and number of channels matches the number of elements in # Make sure that the product of the height, width and number of channels matches the number of elements in
# `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36. # `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
-        make_images_kwargs=dict(
-            DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[datapoints.ColorSpace.RGB]
-        ),
+        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]),
supports_pil=False, supports_pil=False,
), ),
ConsistencyConfig( ConsistencyConfig(
...@@ -150,9 +148,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -150,9 +148,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(num_output_channels=1), ArgsKwargs(num_output_channels=1),
ArgsKwargs(num_output_channels=3), ArgsKwargs(num_output_channels=3),
], ],
-        make_images_kwargs=dict(
-            DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[datapoints.ColorSpace.RGB, datapoints.ColorSpace.GRAY]
-        ),
+        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.ConvertDtype, prototype_transforms.ConvertDtype,
...@@ -174,10 +170,10 @@ CONSISTENCY_CONFIGS = [ ...@@ -174,10 +170,10 @@ CONSISTENCY_CONFIGS = [
[ArgsKwargs()], [ArgsKwargs()],
make_images_kwargs=dict( make_images_kwargs=dict(
color_spaces=[ color_spaces=[
datapoints.ColorSpace.GRAY, "GRAY",
datapoints.ColorSpace.GRAY_ALPHA, "GRAY_ALPHA",
datapoints.ColorSpace.RGB, "RGB",
datapoints.ColorSpace.RGB_ALPHA, "RGBA",
], ],
extra_dims=[()], extra_dims=[()],
), ),
...@@ -911,7 +907,7 @@ class TestRefDetTransforms: ...@@ -911,7 +907,7 @@ class TestRefDetTransforms:
size = (600, 800) size = (600, 800)
num_objects = 22 num_objects = 22
pil_image = to_image_pil(make_image(size=size, color_space=datapoints.ColorSpace.RGB)) pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
target = { target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80), "labels": make_label(extra_dims=(num_objects,), categories=80),
...@@ -921,7 +917,7 @@ class TestRefDetTransforms: ...@@ -921,7 +917,7 @@ class TestRefDetTransforms:
yield (pil_image, target) yield (pil_image, target)
tensor_image = torch.Tensor(make_image(size=size, color_space=datapoints.ColorSpace.RGB)) tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
target = { target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80), "labels": make_label(extra_dims=(num_objects,), categories=80),
...@@ -931,7 +927,7 @@ class TestRefDetTransforms: ...@@ -931,7 +927,7 @@ class TestRefDetTransforms:
yield (tensor_image, target) yield (tensor_image, target)
datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB) datapoint_image = make_image(size=size, color_space="RGB")
target = { target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float), "boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80), "labels": make_label(extra_dims=(num_objects,), categories=80),
...@@ -1015,7 +1011,7 @@ class TestRefSegTransforms: ...@@ -1015,7 +1011,7 @@ class TestRefSegTransforms:
conv_fns.extend([torch.Tensor, lambda x: x]) conv_fns.extend([torch.Tensor, lambda x: x])
for conv_fn in conv_fns: for conv_fn in conv_fns:
datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype) datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8) datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
dp = (conv_fn(datapoint_image), datapoint_mask) dp = (conv_fn(datapoint_image), datapoint_mask)
......
...@@ -340,7 +340,6 @@ class TestDispatchers: ...@@ -340,7 +340,6 @@ class TestDispatchers:
"dispatcher", "dispatcher",
[ [
F.clamp_bounding_box, F.clamp_bounding_box,
F.convert_color_space,
F.get_dimensions, F.get_dimensions,
F.get_image_num_channels, F.get_image_num_channels,
F.get_image_size, F.get_image_size,
......
...@@ -11,7 +11,7 @@ from torchvision.prototype.transforms.functional import to_image_pil ...@@ -11,7 +11,7 @@ from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import has_all, has_any from torchvision.prototype.transforms.utils import has_all, has_any
IMAGE = make_image(color_space=datapoints.ColorSpace.RGB) IMAGE = make_image(color_space="RGB")
BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size) BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size)
MASK = make_detection_mask(size=IMAGE.spatial_size) MASK = make_detection_mask(size=IMAGE.spatial_size)
......
from ._bounding_box import BoundingBox, BoundingBoxFormat from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._label import Label, OneHotLabel from ._label import Label, OneHotLabel
from ._mask import Mask from ._mask import Mask
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
from __future__ import annotations from __future__ import annotations
import warnings
from typing import Any, List, Optional, Tuple, Union from typing import Any, List, Optional, Tuple, Union
import PIL.Image import PIL.Image
import torch import torch
from torchvision._utils import StrEnum
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ._datapoint import Datapoint, FillTypeJIT from ._datapoint import Datapoint, FillTypeJIT
class ColorSpace(StrEnum):
OTHER = StrEnum.auto()
GRAY = StrEnum.auto()
GRAY_ALPHA = StrEnum.auto()
RGB = StrEnum.auto()
RGB_ALPHA = StrEnum.auto()
@classmethod
def from_pil_mode(cls, mode: str) -> ColorSpace:
if mode == "L":
return cls.GRAY
elif mode == "LA":
return cls.GRAY_ALPHA
elif mode == "RGB":
return cls.RGB
elif mode == "RGBA":
return cls.RGB_ALPHA
else:
return cls.OTHER
@staticmethod
def from_tensor_shape(shape: List[int]) -> ColorSpace:
return _from_tensor_shape(shape)
def _from_tensor_shape(shape: List[int]) -> ColorSpace:
# Needed as a standalone method for JIT
ndim = len(shape)
if ndim < 2:
return ColorSpace.OTHER
elif ndim == 2:
return ColorSpace.GRAY
num_channels = shape[-3]
if num_channels == 1:
return ColorSpace.GRAY
elif num_channels == 2:
return ColorSpace.GRAY_ALPHA
elif num_channels == 3:
return ColorSpace.RGB
elif num_channels == 4:
return ColorSpace.RGB_ALPHA
else:
return ColorSpace.OTHER
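With the `ColorSpace` enum and this shape-inference helper removed, code that still needs to branch on channel layout has to inspect the channel dimension directly. Below is a minimal sketch of such a check as a hypothetical standalone helper (not part of torchvision), mirroring the removed `_from_tensor_shape` logic and using the string names the new test helpers use:

```python
import torch


def guess_layout(image: torch.Tensor) -> str:
    # Hypothetical helper: map channel count to a layout name, roughly
    # what the removed ColorSpace.from_tensor_shape() used to do.
    if image.ndim < 2:
        return "OTHER"
    if image.ndim == 2:
        return "GRAY"
    return {1: "GRAY", 2: "GRAY_ALPHA", 3: "RGB", 4: "RGBA"}.get(image.shape[-3], "OTHER")


print(guess_layout(torch.rand(3, 8, 8)))  # RGB
print(guess_layout(torch.rand(4, 8, 8)))  # RGBA
```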
class Image(Datapoint):
-    color_space: ColorSpace

    @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image:
+    def _wrap(cls, tensor: torch.Tensor) -> Image:
        image = tensor.as_subclass(cls)
-        image.color_space = color_space
        return image

    def __new__(
        cls,
        data: Any,
        *,
-        color_space: Optional[Union[ColorSpace, str]] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: bool = False,
@@ -81,26 +29,14 @@ class Image(Datapoint):
        elif tensor.ndim == 2:
            tensor = tensor.unsqueeze(0)
-        if color_space is None:
-            color_space = ColorSpace.from_tensor_shape(tensor.shape)  # type: ignore[arg-type]
-            if color_space == ColorSpace.OTHER:
-                warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
-        elif isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-        elif not isinstance(color_space, ColorSpace):
-            raise ValueError
-        return cls._wrap(tensor, color_space=color_space)
+        return cls._wrap(tensor)

    @classmethod
-    def wrap_like(cls, other: Image, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Image:
-        return cls._wrap(
-            tensor,
-            color_space=color_space if color_space is not None else other.color_space,
-        )
+    def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
+        return cls._wrap(tensor)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(color_space=self.color_space)
+        return self._make_repr()

    @property
    def spatial_size(self) -> Tuple[int, int]:
......
from __future__ import annotations from __future__ import annotations
import warnings
from typing import Any, List, Optional, Tuple, Union from typing import Any, List, Optional, Tuple, Union
import torch import torch
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ._datapoint import Datapoint, FillTypeJIT from ._datapoint import Datapoint, FillTypeJIT
from ._image import ColorSpace
class Video(Datapoint):
-    color_space: ColorSpace

    @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video:
+    def _wrap(cls, tensor: torch.Tensor) -> Video:
        video = tensor.as_subclass(cls)
-        video.color_space = color_space
        return video

    def __new__(
        cls,
        data: Any,
        *,
-        color_space: Optional[Union[ColorSpace, str]] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: bool = False,
@@ -31,28 +25,14 @@ class Video(Datapoint):
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        if data.ndim < 4:
            raise ValueError
-        video = super().__new__(cls, data, requires_grad=requires_grad)
-        if color_space is None:
-            color_space = ColorSpace.from_tensor_shape(video.shape)  # type: ignore[arg-type]
-            if color_space == ColorSpace.OTHER:
-                warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
-        elif isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-        elif not isinstance(color_space, ColorSpace):
-            raise ValueError
-        return cls._wrap(tensor, color_space=color_space)
+        return cls._wrap(tensor)

    @classmethod
-    def wrap_like(cls, other: Video, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Video:
-        return cls._wrap(
-            tensor,
-            color_space=color_space if color_space is not None else other.color_space,
-        )
+    def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video:
+        return cls._wrap(tensor)

    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(color_space=self.color_space)
+        return self._make_repr()

    @property
    def spatial_size(self) -> Tuple[int, int]:
......
...@@ -39,7 +39,7 @@ from ._geometry import ( ...@@ -39,7 +39,7 @@ from ._geometry import (
ScaleJitter, ScaleJitter,
TenCrop, TenCrop,
) )
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
from ._misc import ( from ._misc import (
GaussianBlur, GaussianBlur,
Identity, Identity,
......
...@@ -28,6 +28,7 @@ class ToTensor(Transform): ...@@ -28,6 +28,7 @@ class ToTensor(Transform):
return _F.to_tensor(inpt) return _F.to_tensor(inpt)
# TODO: in other PR (?) undeprecate those and make them use _rgb_to_gray?
class Grayscale(Transform): class Grayscale(Transform):
_transformed_types = ( _transformed_types = (
datapoints.Image, datapoints.Image,
...@@ -62,7 +63,7 @@ class Grayscale(Transform): ...@@ -62,7 +63,7 @@ class Grayscale(Transform):
) -> Union[datapoints.ImageType, datapoints.VideoType]: ) -> Union[datapoints.ImageType, datapoints.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
if isinstance(inpt, (datapoints.Image, datapoints.Video)): if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output return output
...@@ -98,5 +99,5 @@ class RandomGrayscale(_RandomApplyTransform): ...@@ -98,5 +99,5 @@ class RandomGrayscale(_RandomApplyTransform):
) -> Union[datapoints.ImageType, datapoints.VideoType]: ) -> Union[datapoints.ImageType, datapoints.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
if isinstance(inpt, (datapoints.Image, datapoints.Video)): if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type] output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output return output
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Union
import PIL.Image
import torch import torch
...@@ -46,35 +44,6 @@ class ConvertDtype(Transform): ...@@ -46,35 +44,6 @@ class ConvertDtype(Transform):
ConvertImageDtype = ConvertDtype ConvertImageDtype = ConvertDtype
class ConvertColorSpace(Transform):
_transformed_types = (
is_simple_tensor,
datapoints.Image,
PIL.Image.Image,
datapoints.Video,
)
def __init__(
self,
color_space: Union[str, datapoints.ColorSpace],
old_color_space: Optional[Union[str, datapoints.ColorSpace]] = None,
) -> None:
super().__init__()
if isinstance(color_space, str):
color_space = datapoints.ColorSpace.from_str(color_space)
self.color_space = color_space
if isinstance(old_color_space, str):
old_color_space = datapoints.ColorSpace.from_str(old_color_space)
self.old_color_space = old_color_space
def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space)
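For the common RGB-to-grayscale case, the grayscale transforms that survive this commit cover what `ConvertColorSpace` was typically used for. A minimal sketch, assuming `Grayscale` is exported from `torchvision.prototype.transforms` as the `Grayscale(Transform)` hunk above suggests:

```python
import torch
from torchvision.prototype import datapoints, transforms

rgb = datapoints.Image(torch.rand(3, 8, 8))

# Previously: transforms.ConvertColorSpace(color_space="GRAY", old_color_space="RGB")(rgb)
# The remaining Grayscale transform handles the RGB -> GRAY direction.
gray = transforms.Grayscale()(rgb)
print(gray.shape)  # torch.Size([1, 8, 8])
```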
class ClampBoundingBoxes(Transform): class ClampBoundingBoxes(Transform):
_transformed_types = (datapoints.BoundingBox,) _transformed_types = (datapoints.BoundingBox,)
......
...@@ -7,10 +7,6 @@ from ._utils import is_simple_tensor # usort: skip ...@@ -7,10 +7,6 @@ from ._utils import is_simple_tensor # usort: skip
from ._meta import ( from ._meta import (
clamp_bounding_box, clamp_bounding_box,
convert_format_bounding_box, convert_format_bounding_box,
convert_color_space_image_tensor,
convert_color_space_image_pil,
convert_color_space_video,
convert_color_space,
convert_dtype_image_tensor, convert_dtype_image_tensor,
convert_dtype, convert_dtype,
convert_dtype_video, convert_dtype_video,
......
...@@ -27,13 +27,11 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima ...@@ -27,13 +27,11 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale( def rgb_to_grayscale(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        old_color_space = datapoints._image._from_tensor_shape(inpt.shape)  # type: ignore[arg-type]
-    else:
-        old_color_space = None
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            inpt = inpt.as_subclass(torch.Tensor)
+    old_color_space = None  # TODO: remove when un-deprecating
+    if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance(
+        inpt, (datapoints.Image, datapoints.Video)
+    ):
+        inpt = inpt.as_subclass(torch.Tensor)

    call = ", num_output_channels=3" if num_output_channels == 3 else ""
    replacement = (
......
from typing import List, Optional, Tuple, Union from typing import List, Tuple, Union
import PIL.Image import PIL.Image
import torch import torch
from torchvision.prototype import datapoints from torchvision.prototype import datapoints
from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace from torchvision.prototype.datapoints import BoundingBoxFormat
from torchvision.transforms import functional_pil as _FP from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value from torchvision.transforms.functional_tensor import _max_value
...@@ -225,29 +225,6 @@ def clamp_bounding_box( ...@@ -225,29 +225,6 @@ def clamp_bounding_box(
return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True) return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)
def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
if not torch.all(alpha == _max_value(alpha.dtype)):
raise RuntimeError(
"Stripping the alpha channel if it contains values other than the max value is not supported."
)
return image
def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
if alpha is None:
shape = list(image.shape)
shape[-3] = 1
alpha = torch.full(shape, _max_value(image.dtype), dtype=image.dtype, device=image.device)
return torch.cat((image, alpha), dim=-3)
def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
repeats = [1] * grayscale.ndim
repeats[-3] = 3
return grayscale.repeat(repeats)
def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor: def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
r, g, b = image.unbind(dim=-3) r, g, b = image.unbind(dim=-3)
l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
...@@ -257,107 +234,6 @@ def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor: ...@@ -257,107 +234,6 @@ def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
return l_img return l_img
def convert_color_space_image_tensor(
image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
) -> torch.Tensor:
if new_color_space == old_color_space:
return image
if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
if old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.GRAY_ALPHA:
return _add_alpha(image)
elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB:
return _gray_to_rgb(image)
elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB_ALPHA:
return _add_alpha(_gray_to_rgb(image))
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.GRAY:
return _strip_alpha(image)
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
return _gray_to_rgb(_strip_alpha(image))
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
return _add_alpha(_gray_to_rgb(image), alpha)
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(image)
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY_ALPHA:
return _add_alpha(_rgb_to_gray(image))
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.RGB_ALPHA:
return _add_alpha(image)
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(_strip_alpha(image))
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
return _add_alpha(_rgb_to_gray(image), alpha)
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
return _strip_alpha(image)
else:
raise RuntimeError(f"Conversion from {old_color_space} to {new_color_space} is not supported.")
_COLOR_SPACE_TO_PIL_MODE = {
ColorSpace.GRAY: "L",
ColorSpace.GRAY_ALPHA: "LA",
ColorSpace.RGB: "RGB",
ColorSpace.RGB_ALPHA: "RGBA",
}
@torch.jit.unused
def convert_color_space_image_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image:
old_mode = image.mode
try:
new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
except KeyError:
raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")
if image.mode == new_mode:
return image
return image.convert(new_mode)
def convert_color_space_video(
video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
) -> torch.Tensor:
return convert_color_space_image_tensor(video, old_color_space=old_color_space, new_color_space=new_color_space)
def convert_color_space(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT],
color_space: ColorSpace,
old_color_space: Optional[ColorSpace] = None,
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(convert_color_space)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if old_color_space is None:
raise RuntimeError(
"In order to convert the color space of simple tensors, "
"the `old_color_space=...` parameter needs to be passed."
)
return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space)
elif isinstance(inpt, datapoints.Image):
output = convert_color_space_image_tensor(
inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
)
return datapoints.Image.wrap_like(inpt, output, color_space=color_space)
elif isinstance(inpt, datapoints.Video):
output = convert_color_space_video(
inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
)
return datapoints.Video.wrap_like(inpt, output, color_space=color_space)
elif isinstance(inpt, PIL.Image.Image):
return convert_color_space_image_pil(inpt, color_space=color_space)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _num_value_bits(dtype: torch.dtype) -> int: def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8: if dtype == torch.uint8:
return 8 return 8
......
...@@ -1234,6 +1234,9 @@ def affine( ...@@ -1234,6 +1234,9 @@ def affine(
return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
# Looks like to_grayscale() is a stand-alone functional that is never called
# from the transform classes. Perhaps it's still here for BC? I can't be
# bothered to dig. Anyway, this can be deprecated as we migrate to V2.
@torch.jit.unused @torch.jit.unused
def to_grayscale(img, num_output_channels=1): def to_grayscale(img, num_output_channels=1):
"""Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image. """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
......