Unverified Commit 5dd95944 authored by Nicolas Hug, committed by GitHub

Remove `color_space` metadata and `ConvertColorSpace()` transform (#7120)

parent c206a471
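
For context on the user-facing change, here is a minimal before/after sketch (hypothetical user code, not part of the diff): the `color_space` keyword disappears from the `Image` constructor, and the channel dimension of the tensor is the only color information that remains.

```python
import torch
from torchvision.prototype import datapoints

data = torch.rand(3, 32, 32)

# Before this PR, the color space was carried as metadata on the datapoint:
#     image = datapoints.Image(data, color_space=datapoints.ColorSpace.RGB)
# After this PR, the datapoint is just the wrapped tensor; the number of
# channels (data.shape[-3]) is all the "color space" information left.
image = datapoints.Image(data)
print(image.shape)  # torch.Size([3, 32, 32])
```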
......@@ -238,7 +238,6 @@ class TensorLoader:
@dataclasses.dataclass
class ImageLoader(TensorLoader):
color_space: datapoints.ColorSpace
spatial_size: Tuple[int, int] = dataclasses.field(init=False)
num_channels: int = dataclasses.field(init=False)
......@@ -248,10 +247,10 @@ class ImageLoader(TensorLoader):
NUM_CHANNELS_MAP = {
datapoints.ColorSpace.GRAY: 1,
datapoints.ColorSpace.GRAY_ALPHA: 2,
datapoints.ColorSpace.RGB: 3,
datapoints.ColorSpace.RGB_ALPHA: 4,
"GRAY": 1,
"GRAY_ALPHA": 2,
"RGB": 3,
"RGBA": 4,
}
......@@ -265,7 +264,7 @@ def get_num_channels(color_space):
def make_image_loader(
size="random",
*,
color_space=datapoints.ColorSpace.RGB,
color_space="RGB",
extra_dims=(),
dtype=torch.float32,
constant_alpha=True,
......@@ -276,11 +275,11 @@ def make_image_loader(
def fn(shape, dtype, device):
max_value = get_max_value(dtype)
data = torch.testing.make_tensor(shape, low=0, high=max_value, dtype=dtype, device=device)
if color_space in {datapoints.ColorSpace.GRAY_ALPHA, datapoints.ColorSpace.RGB_ALPHA} and constant_alpha:
if color_space in {"GRAY_ALPHA", "RGBA"} and constant_alpha:
data[..., -1, :, :] = max_value
return datapoints.Image(data, color_space=color_space)
return datapoints.Image(data)
return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype, color_space=color_space)
return ImageLoader(fn, shape=(*extra_dims, num_channels, *size), dtype=dtype)
make_image = from_loader(make_image_loader)
......@@ -290,10 +289,10 @@ def make_image_loaders(
*,
sizes=DEFAULT_SPATIAL_SIZES,
color_spaces=(
datapoints.ColorSpace.GRAY,
datapoints.ColorSpace.GRAY_ALPHA,
datapoints.ColorSpace.RGB,
datapoints.ColorSpace.RGB_ALPHA,
"GRAY",
"GRAY_ALPHA",
"RGB",
"RGBA",
),
extra_dims=DEFAULT_EXTRA_DIMS,
dtypes=(torch.float32, torch.uint8),
......@@ -306,7 +305,7 @@ def make_image_loaders(
make_images = from_loaders(make_image_loaders)
def make_image_loader_for_interpolation(size="random", *, color_space=datapoints.ColorSpace.RGB, dtype=torch.uint8):
def make_image_loader_for_interpolation(size="random", *, color_space="RGB", dtype=torch.uint8):
size = _parse_spatial_size(size)
num_channels = get_num_channels(color_space)
......@@ -318,24 +317,24 @@ def make_image_loader_for_interpolation(size="random", *, color_space=datapoints
.resize((width, height))
.convert(
{
datapoints.ColorSpace.GRAY: "L",
datapoints.ColorSpace.GRAY_ALPHA: "LA",
datapoints.ColorSpace.RGB: "RGB",
datapoints.ColorSpace.RGB_ALPHA: "RGBA",
"GRAY": "L",
"GRAY_ALPHA": "LA",
"RGB": "RGB",
"RGBA": "RGBA",
}[color_space]
)
)
image_tensor = convert_dtype_image_tensor(to_image_tensor(image_pil).to(device=device), dtype=dtype)
return datapoints.Image(image_tensor, color_space=color_space)
return datapoints.Image(image_tensor)
return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, color_space=color_space)
return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype)
def make_image_loaders_for_interpolation(
sizes=((233, 147),),
color_spaces=(datapoints.ColorSpace.RGB,),
color_spaces=("RGB",),
dtypes=(torch.uint8,),
):
for params in combinations_grid(size=sizes, color_space=color_spaces, dtype=dtypes):
......@@ -583,7 +582,7 @@ class VideoLoader(ImageLoader):
def make_video_loader(
size="random",
*,
color_space=datapoints.ColorSpace.RGB,
color_space="RGB",
num_frames="random",
extra_dims=(),
dtype=torch.uint8,
......@@ -592,12 +591,10 @@ def make_video_loader(
num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
def fn(shape, dtype, device):
video = make_image(size=shape[-2:], color_space=color_space, extra_dims=shape[:-3], dtype=dtype, device=device)
return datapoints.Video(video, color_space=color_space)
video = make_image(size=shape[-2:], extra_dims=shape[:-3], dtype=dtype, device=device)
return datapoints.Video(video)
return VideoLoader(
fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype, color_space=color_space
)
return VideoLoader(fn, shape=(*extra_dims, num_frames, get_num_channels(color_space), *size), dtype=dtype)
make_video = from_loader(make_video_loader)
......@@ -607,8 +604,8 @@ def make_video_loaders(
*,
sizes=DEFAULT_SPATIAL_SIZES,
color_spaces=(
datapoints.ColorSpace.GRAY,
datapoints.ColorSpace.RGB,
"GRAY",
"RGB",
),
num_frames=(1, 0, "random"),
extra_dims=DEFAULT_EXTRA_DIMS,
......
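
As a sketch of how the updated test helpers are now driven (hypothetical usage, with names taken from the file above): color spaces are passed as plain strings, and the channel count of each generated loader comes from the string-keyed `NUM_CHANNELS_MAP`.

```python
# Hypothetical test-side usage of the helpers defined above.
from prototype_common_utils import make_image_loaders

for loader in make_image_loaders(sizes=[(8, 8)], color_spaces=["GRAY_ALPHA", "RGB"], extra_dims=[()]):
    # loader.shape is (num_channels, height, width), e.g. (2, 8, 8) for "GRAY_ALPHA".
    print(loader.shape, loader.dtype)
```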
......@@ -9,7 +9,6 @@ import pytest
import torch.testing
import torchvision.ops
import torchvision.prototype.transforms.functional as F
from common_utils import cycle_over
from datasets_utils import combinations_grid
from prototype_common_utils import (
ArgsKwargs,
......@@ -261,14 +260,12 @@ def _get_resize_sizes(spatial_size):
def sample_inputs_resize_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]):
for size in _get_resize_sizes(image_loader.spatial_size):
yield ArgsKwargs(image_loader, size=size)
for image_loader, interpolation in itertools.product(
make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB]),
make_image_loaders(sizes=["random"], color_spaces=["RGB"]),
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR,
......@@ -472,7 +469,7 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs):
def sample_inputs_affine_image_tensor():
make_affine_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
)
for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS):
......@@ -684,69 +681,6 @@ KERNEL_INFOS.append(
)
def sample_inputs_convert_color_space_image_tensor():
color_spaces = sorted(
set(datapoints.ColorSpace) - {datapoints.ColorSpace.OTHER}, key=lambda color_space: color_space.value
)
for old_color_space, new_color_space in cycle_over(color_spaces):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=[old_color_space], constant_alpha=True):
yield ArgsKwargs(image_loader, old_color_space=old_color_space, new_color_space=new_color_space)
for color_space in color_spaces:
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=[color_space], dtypes=[torch.float32], constant_alpha=True
):
yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space)
@pil_reference_wrapper
def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space):
color_space_pil = datapoints.ColorSpace.from_pil_mode(image_pil.mode)
if color_space_pil != old_color_space:
raise pytest.UsageError(
f"Converting the tensor image into an PIL image changed the colorspace "
f"from {old_color_space} to {color_space_pil}"
)
return F.convert_color_space_image_pil(image_pil, color_space=new_color_space)
def reference_inputs_convert_color_space_image_tensor():
for args_kwargs in sample_inputs_convert_color_space_image_tensor():
(image_loader, *other_args), kwargs = args_kwargs
if len(image_loader.shape) == 3 and image_loader.dtype == torch.uint8:
yield args_kwargs
def sample_inputs_convert_color_space_video():
color_spaces = [datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB]
for old_color_space, new_color_space in cycle_over(color_spaces):
for video_loader in make_video_loaders(sizes=["random"], color_spaces=[old_color_space], num_frames=["random"]):
yield ArgsKwargs(video_loader, old_color_space=old_color_space, new_color_space=new_color_space)
KERNEL_INFOS.extend(
[
KernelInfo(
F.convert_color_space_image_tensor,
sample_inputs_fn=sample_inputs_convert_color_space_image_tensor,
reference_fn=reference_convert_color_space_image_tensor,
reference_inputs_fn=reference_inputs_convert_color_space_image_tensor,
closeness_kwargs={
**pil_reference_pixel_difference(),
**float32_vs_uint8_pixel_difference(),
},
),
KernelInfo(
F.convert_color_space_video,
sample_inputs_fn=sample_inputs_convert_color_space_video,
),
]
)
def sample_inputs_vertical_flip_image_tensor():
for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]):
yield ArgsKwargs(image_loader)
......@@ -822,7 +756,7 @@ _ROTATE_ANGLES = [-87, 15, 90]
def sample_inputs_rotate_image_tensor():
make_rotate_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
)
for image_loader in make_rotate_image_loaders():
......@@ -904,7 +838,7 @@ _CROP_PARAMS = combinations_grid(top=[-8, 0, 9], left=[-8, 0, 9], height=[12, 20
def sample_inputs_crop_image_tensor():
for image_loader, params in itertools.product(
make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
[
dict(top=4, left=3, height=7, width=8),
dict(top=-1, left=3, height=7, width=8),
......@@ -1090,7 +1024,7 @@ _PAD_PARAMS = combinations_grid(
def sample_inputs_pad_image_tensor():
make_pad_image_loaders = functools.partial(
make_image_loaders, sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]
make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]
)
for image_loader, padding in itertools.product(
......@@ -1406,7 +1340,7 @@ _CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]
def sample_inputs_center_crop_image_tensor():
for image_loader, output_size in itertools.product(
make_image_loaders(sizes=[(16, 17)], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
make_image_loaders(sizes=[(16, 17)], color_spaces=["RGB"], dtypes=[torch.float32]),
[
# valid `output_size` types for which cropping is applied to both dimensions
*[5, (4,), (2, 3), [6], [3, 2]],
......@@ -1492,9 +1426,7 @@ KERNEL_INFOS.extend(
def sample_inputs_gaussian_blur_image_tensor():
make_gaussian_blur_image_loaders = functools.partial(
make_image_loaders, sizes=[(7, 33)], color_spaces=[datapoints.ColorSpace.RGB]
)
make_gaussian_blur_image_loaders = functools.partial(make_image_loaders, sizes=[(7, 33)], color_spaces=["RGB"])
for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
yield ArgsKwargs(image_loader, kernel_size=kernel_size)
......@@ -1531,9 +1463,7 @@ KERNEL_INFOS.extend(
def sample_inputs_equalize_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader)
......@@ -1560,7 +1490,7 @@ def reference_inputs_equalize_image_tensor():
spatial_size = (256, 256)
for dtype, color_space, fn in itertools.product(
[torch.uint8],
[datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB],
["GRAY", "RGB"],
[
lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
lambda shape, dtype, device: torch.full(
......@@ -1585,9 +1515,7 @@ def reference_inputs_equalize_image_tensor():
],
],
):
image_loader = ImageLoader(
fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype, color_space=color_space
)
image_loader = ImageLoader(fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype)
yield ArgsKwargs(image_loader)
......@@ -1615,16 +1543,12 @@ KERNEL_INFOS.extend(
def sample_inputs_invert_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader)
def reference_inputs_invert_image_tensor():
for image_loader in make_image_loaders(
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
):
for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
yield ArgsKwargs(image_loader)
......@@ -1655,17 +1579,13 @@ _POSTERIZE_BITS = [1, 4, 8]
def sample_inputs_posterize_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])
def reference_inputs_posterize_image_tensor():
for image_loader, bits in itertools.product(
make_image_loaders(
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_POSTERIZE_BITS,
):
yield ArgsKwargs(image_loader, bits=bits)
......@@ -1702,16 +1622,12 @@ def _get_solarize_thresholds(dtype):
def sample_inputs_solarize_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, threshold=next(_get_solarize_thresholds(image_loader.dtype)))
def reference_inputs_solarize_image_tensor():
for image_loader in make_image_loaders(
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
):
for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
for threshold in _get_solarize_thresholds(image_loader.dtype):
yield ArgsKwargs(image_loader, threshold=threshold)
......@@ -1745,16 +1661,12 @@ KERNEL_INFOS.extend(
def sample_inputs_autocontrast_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader)
def reference_inputs_autocontrast_image_tensor():
for image_loader in make_image_loaders(
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
):
for image_loader in make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]):
yield ArgsKwargs(image_loader)
......@@ -1790,16 +1702,14 @@ _ADJUST_SHARPNESS_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_sharpness_image_tensor():
for image_loader in make_image_loaders(
sizes=["random", (2, 2)],
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB),
color_spaces=("GRAY", "RGB"),
):
yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])
def reference_inputs_adjust_sharpness_image_tensor():
for image_loader, sharpness_factor in itertools.product(
make_image_loaders(
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_SHARPNESS_FACTORS,
):
yield ArgsKwargs(image_loader, sharpness_factor=sharpness_factor)
......@@ -1863,17 +1773,13 @@ _ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_brightness_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])
def reference_inputs_adjust_brightness_image_tensor():
for image_loader, brightness_factor in itertools.product(
make_image_loaders(
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_BRIGHTNESS_FACTORS,
):
yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)
......@@ -1907,17 +1813,13 @@ _ADJUST_CONTRAST_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_contrast_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, contrast_factor=_ADJUST_CONTRAST_FACTORS[0])
def reference_inputs_adjust_contrast_image_tensor():
for image_loader, contrast_factor in itertools.product(
make_image_loaders(
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_CONTRAST_FACTORS,
):
yield ArgsKwargs(image_loader, contrast_factor=contrast_factor)
......@@ -1959,17 +1861,13 @@ _ADJUST_GAMMA_GAMMAS_GAINS = [
def sample_inputs_adjust_gamma_image_tensor():
gamma, gain = _ADJUST_GAMMA_GAMMAS_GAINS[0]
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
def reference_inputs_adjust_gamma_image_tensor():
for image_loader, (gamma, gain) in itertools.product(
make_image_loaders(
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_GAMMA_GAMMAS_GAINS,
):
yield ArgsKwargs(image_loader, gamma=gamma, gain=gain)
......@@ -2007,17 +1905,13 @@ _ADJUST_HUE_FACTORS = [-0.1, 0.5]
def sample_inputs_adjust_hue_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, hue_factor=_ADJUST_HUE_FACTORS[0])
def reference_inputs_adjust_hue_image_tensor():
for image_loader, hue_factor in itertools.product(
make_image_loaders(
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_HUE_FACTORS,
):
yield ArgsKwargs(image_loader, hue_factor=hue_factor)
......@@ -2053,17 +1947,13 @@ _ADJUST_SATURATION_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_saturation_image_tensor():
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB)
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, saturation_factor=_ADJUST_SATURATION_FACTORS[0])
def reference_inputs_adjust_saturation_image_tensor():
for image_loader, saturation_factor in itertools.product(
make_image_loaders(
color_spaces=(datapoints.ColorSpace.GRAY, datapoints.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
),
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_SATURATION_FACTORS,
):
yield ArgsKwargs(image_loader, saturation_factor=saturation_factor)
......@@ -2128,7 +2018,7 @@ def sample_inputs_five_crop_image_tensor():
for size in _FIVE_TEN_CROP_SIZES:
for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)],
color_spaces=[datapoints.ColorSpace.RGB],
color_spaces=["RGB"],
dtypes=[torch.float32],
):
yield ArgsKwargs(image_loader, size=size)
......@@ -2152,7 +2042,7 @@ def sample_inputs_ten_crop_image_tensor():
for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
for image_loader in make_image_loaders(
sizes=[_get_five_ten_crop_spatial_size(size)],
color_spaces=[datapoints.ColorSpace.RGB],
color_spaces=["RGB"],
dtypes=[torch.float32],
):
yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
......@@ -2226,7 +2116,7 @@ _NORMALIZE_MEANS_STDS = [
def sample_inputs_normalize_image_tensor():
for image_loader, (mean, std) in itertools.product(
make_image_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32]),
make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32]),
_NORMALIZE_MEANS_STDS,
):
yield ArgsKwargs(image_loader, mean=mean, std=std)
......@@ -2242,7 +2132,7 @@ def reference_normalize_image_tensor(image, mean, std, inplace=False):
def reference_inputs_normalize_image_tensor():
yield ArgsKwargs(
make_image_loader(size=(32, 32), color_space=datapoints.ColorSpace.RGB, extra_dims=[1]),
make_image_loader(size=(32, 32), color_space="RGB", extra_dims=[1]),
mean=[0.5, 0.5, 0.5],
std=[1.0, 1.0, 1.0],
)
......@@ -2251,7 +2141,7 @@ def reference_inputs_normalize_image_tensor():
def sample_inputs_normalize_video():
mean, std = _NORMALIZE_MEANS_STDS[0]
for video_loader in make_video_loaders(
sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=["random"], dtypes=[torch.float32]
sizes=["random"], color_spaces=["RGB"], num_frames=["random"], dtypes=[torch.float32]
):
yield ArgsKwargs(video_loader, mean=mean, std=std)
......@@ -2285,9 +2175,7 @@ def sample_inputs_convert_dtype_image_tensor():
# conversion cannot be performed safely
continue
for image_loader in make_image_loaders(
sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], dtypes=[input_dtype]
):
for image_loader in make_image_loaders(sizes=["random"], color_spaces=["RGB"], dtypes=[input_dtype]):
yield ArgsKwargs(image_loader, dtype=output_dtype)
......@@ -2414,7 +2302,7 @@ def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4):
def reference_inputs_uniform_temporal_subsample_video():
for video_loader in make_video_loaders(sizes=["random"], color_spaces=[datapoints.ColorSpace.RGB], num_frames=[10]):
for video_loader in make_video_loaders(sizes=["random"], color_spaces=["RGB"], num_frames=[10]):
for num_samples in range(1, video_loader.shape[-4] + 1):
yield ArgsKwargs(video_loader, num_samples)
......
......@@ -161,8 +161,8 @@ class TestSmoke:
itertools.chain.from_iterable(
fn(
color_spaces=[
datapoints.ColorSpace.GRAY,
datapoints.ColorSpace.RGB,
"GRAY",
"RGB",
],
dtypes=[torch.uint8],
extra_dims=[(), (4,)],
......@@ -192,7 +192,7 @@ class TestSmoke:
(
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
itertools.chain.from_iterable(
fn(color_spaces=[datapoints.ColorSpace.RGB], dtypes=[torch.float32])
fn(color_spaces=["RGB"], dtypes=[torch.float32])
for fn in [
make_images,
make_vanilla_tensor_images,
......@@ -221,45 +221,6 @@ class TestSmoke:
def test_random_resized_crop(self, transform, input):
transform(input)
@parametrize(
[
(
transforms.ConvertColorSpace(color_space=new_color_space, old_color_space=old_color_space),
itertools.chain.from_iterable(
[
fn(color_spaces=[old_color_space])
for fn in (
make_images,
make_vanilla_tensor_images,
make_pil_images,
make_videos,
)
]
),
)
for old_color_space, new_color_space in itertools.product(
[
datapoints.ColorSpace.GRAY,
datapoints.ColorSpace.GRAY_ALPHA,
datapoints.ColorSpace.RGB,
datapoints.ColorSpace.RGB_ALPHA,
],
repeat=2,
)
]
)
def test_convert_color_space(self, transform, input):
transform(input)
def test_convert_color_space_unsupported_types(self):
transform = transforms.ConvertColorSpace(
color_space=datapoints.ColorSpace.RGB, old_color_space=datapoints.ColorSpace.GRAY
)
for inpt in [make_bounding_box(format="XYXY"), make_masks()]:
output = transform(inpt)
assert output is inpt
@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomHorizontalFlip:
......@@ -1558,7 +1519,7 @@ class TestFixedSizeCrop:
transform = transforms.FixedSizeCrop(size=crop_size)
flat_inputs = [
make_image(size=spatial_size, color_space=datapoints.ColorSpace.RGB),
make_image(size=spatial_size, color_space="RGB"),
make_bounding_box(
format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape
),
......
......@@ -31,7 +31,7 @@ from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import query_spatial_size
from torchvision.transforms import functional as legacy_F
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[datapoints.ColorSpace.RGB], extra_dims=[(4,)])
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
class ConsistencyConfig:
......@@ -138,9 +138,7 @@ CONSISTENCY_CONFIGS = [
],
# Make sure that the product of the height, width and number of channels matches the number of elements in
# `LINEAR_TRANSFORMATION_MEAN`. For example 2 * 6 * 3 == 4 * 3 * 3 == 36.
make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=[datapoints.ColorSpace.RGB]
),
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(2, 6), (4, 3)], color_spaces=["RGB"]),
supports_pil=False,
),
ConsistencyConfig(
......@@ -150,9 +148,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(num_output_channels=1),
ArgsKwargs(num_output_channels=3),
],
make_images_kwargs=dict(
DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=[datapoints.ColorSpace.RGB, datapoints.ColorSpace.GRAY]
),
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
),
ConsistencyConfig(
prototype_transforms.ConvertDtype,
......@@ -174,10 +170,10 @@ CONSISTENCY_CONFIGS = [
[ArgsKwargs()],
make_images_kwargs=dict(
color_spaces=[
datapoints.ColorSpace.GRAY,
datapoints.ColorSpace.GRAY_ALPHA,
datapoints.ColorSpace.RGB,
datapoints.ColorSpace.RGB_ALPHA,
"GRAY",
"GRAY_ALPHA",
"RGB",
"RGBA",
],
extra_dims=[()],
),
......@@ -911,7 +907,7 @@ class TestRefDetTransforms:
size = (600, 800)
num_objects = 22
pil_image = to_image_pil(make_image(size=size, color_space=datapoints.ColorSpace.RGB))
pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
......@@ -921,7 +917,7 @@ class TestRefDetTransforms:
yield (pil_image, target)
tensor_image = torch.Tensor(make_image(size=size, color_space=datapoints.ColorSpace.RGB))
tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
......@@ -931,7 +927,7 @@ class TestRefDetTransforms:
yield (tensor_image, target)
datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB)
datapoint_image = make_image(size=size, color_space="RGB")
target = {
"boxes": make_bounding_box(spatial_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
......@@ -1015,7 +1011,7 @@ class TestRefSegTransforms:
conv_fns.extend([torch.Tensor, lambda x: x])
for conv_fn in conv_fns:
datapoint_image = make_image(size=size, color_space=datapoints.ColorSpace.RGB, dtype=image_dtype)
datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
dp = (conv_fn(datapoint_image), datapoint_mask)
......
......@@ -340,7 +340,6 @@ class TestDispatchers:
"dispatcher",
[
F.clamp_bounding_box,
F.convert_color_space,
F.get_dimensions,
F.get_image_num_channels,
F.get_image_size,
......
......@@ -11,7 +11,7 @@ from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import has_all, has_any
IMAGE = make_image(color_space=datapoints.ColorSpace.RGB)
IMAGE = make_image(color_space="RGB")
BOUNDING_BOX = make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=IMAGE.spatial_size)
MASK = make_detection_mask(size=IMAGE.spatial_size)
......
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
from ._image import ColorSpace, Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._label import Label, OneHotLabel
from ._mask import Mask
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
from __future__ import annotations
import warnings
from typing import Any, List, Optional, Tuple, Union
import PIL.Image
import torch
from torchvision._utils import StrEnum
from torchvision.transforms.functional import InterpolationMode
from ._datapoint import Datapoint, FillTypeJIT
class ColorSpace(StrEnum):
OTHER = StrEnum.auto()
GRAY = StrEnum.auto()
GRAY_ALPHA = StrEnum.auto()
RGB = StrEnum.auto()
RGB_ALPHA = StrEnum.auto()
@classmethod
def from_pil_mode(cls, mode: str) -> ColorSpace:
if mode == "L":
return cls.GRAY
elif mode == "LA":
return cls.GRAY_ALPHA
elif mode == "RGB":
return cls.RGB
elif mode == "RGBA":
return cls.RGB_ALPHA
else:
return cls.OTHER
@staticmethod
def from_tensor_shape(shape: List[int]) -> ColorSpace:
return _from_tensor_shape(shape)
def _from_tensor_shape(shape: List[int]) -> ColorSpace:
# Needed as a standalone method for JIT
ndim = len(shape)
if ndim < 2:
return ColorSpace.OTHER
elif ndim == 2:
return ColorSpace.GRAY
num_channels = shape[-3]
if num_channels == 1:
return ColorSpace.GRAY
elif num_channels == 2:
return ColorSpace.GRAY_ALPHA
elif num_channels == 3:
return ColorSpace.RGB
elif num_channels == 4:
return ColorSpace.RGB_ALPHA
else:
return ColorSpace.OTHER
class Image(Datapoint):
color_space: ColorSpace
@classmethod
def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image:
def _wrap(cls, tensor: torch.Tensor) -> Image:
image = tensor.as_subclass(cls)
image.color_space = color_space
return image
def __new__(
cls,
data: Any,
*,
color_space: Optional[Union[ColorSpace, str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
......@@ -81,26 +29,14 @@ class Image(Datapoint):
elif tensor.ndim == 2:
tensor = tensor.unsqueeze(0)
if color_space is None:
color_space = ColorSpace.from_tensor_shape(tensor.shape) # type: ignore[arg-type]
if color_space == ColorSpace.OTHER:
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
elif isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
elif not isinstance(color_space, ColorSpace):
raise ValueError
return cls._wrap(tensor, color_space=color_space)
return cls._wrap(tensor)
@classmethod
def wrap_like(cls, other: Image, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Image:
return cls._wrap(
tensor,
color_space=color_space if color_space is not None else other.color_space,
)
def wrap_like(cls, other: Image, tensor: torch.Tensor) -> Image:
return cls._wrap(tensor)
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(color_space=self.color_space)
return self._make_repr()
@property
def spatial_size(self) -> Tuple[int, int]:
......
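
A small illustrative sketch (assumed behavior, based on the constructor kept above) of the slimmed-down `Image` datapoint: 2D input is still promoted to a single channel, but there is no `color_space` attribute to read or guess anymore.

```python
import torch
from torchvision.prototype import datapoints

gray = datapoints.Image(torch.rand(16, 16))    # 2D input is unsqueezed to (1, 16, 16)
rgb = datapoints.Image(torch.rand(3, 16, 16))

# The channel count is the only remaining "color" information.
print(gray.shape[-3], rgb.shape[-3])  # 1 3
```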
from __future__ import annotations
import warnings
from typing import Any, List, Optional, Tuple, Union
import torch
from torchvision.transforms.functional import InterpolationMode
from ._datapoint import Datapoint, FillTypeJIT
from ._image import ColorSpace
class Video(Datapoint):
color_space: ColorSpace
@classmethod
def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video:
def _wrap(cls, tensor: torch.Tensor) -> Video:
video = tensor.as_subclass(cls)
video.color_space = color_space
return video
def __new__(
cls,
data: Any,
*,
color_space: Optional[Union[ColorSpace, str]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str, int]] = None,
requires_grad: bool = False,
......@@ -31,28 +25,14 @@ class Video(Datapoint):
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if data.ndim < 4:
raise ValueError
video = super().__new__(cls, data, requires_grad=requires_grad)
if color_space is None:
color_space = ColorSpace.from_tensor_shape(video.shape) # type: ignore[arg-type]
if color_space == ColorSpace.OTHER:
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
elif isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
elif not isinstance(color_space, ColorSpace):
raise ValueError
return cls._wrap(tensor, color_space=color_space)
return cls._wrap(tensor)
@classmethod
def wrap_like(cls, other: Video, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Video:
return cls._wrap(
tensor,
color_space=color_space if color_space is not None else other.color_space,
)
def wrap_like(cls, other: Video, tensor: torch.Tensor) -> Video:
return cls._wrap(tensor)
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(color_space=self.color_space)
return self._make_repr()
@property
def spatial_size(self) -> Tuple[int, int]:
......
......@@ -39,7 +39,7 @@ from ._geometry import (
ScaleJitter,
TenCrop,
)
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
from ._misc import (
GaussianBlur,
Identity,
......
......@@ -28,6 +28,7 @@ class ToTensor(Transform):
return _F.to_tensor(inpt)
# TODO: in other PR (?) undeprecate those and make them use _rgb_to_gray?
class Grayscale(Transform):
_transformed_types = (
datapoints.Image,
......@@ -62,7 +63,7 @@ class Grayscale(Transform):
) -> Union[datapoints.ImageType, datapoints.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type]
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output
......@@ -98,5 +99,5 @@ class RandomGrayscale(_RandomApplyTransform):
) -> Union[datapoints.ImageType, datapoints.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.GRAY) # type: ignore[arg-type]
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output
from typing import Any, Dict, Optional, Union
import PIL.Image
from typing import Any, Dict, Union
import torch
......@@ -46,35 +44,6 @@ class ConvertDtype(Transform):
ConvertImageDtype = ConvertDtype
class ConvertColorSpace(Transform):
_transformed_types = (
is_simple_tensor,
datapoints.Image,
PIL.Image.Image,
datapoints.Video,
)
def __init__(
self,
color_space: Union[str, datapoints.ColorSpace],
old_color_space: Optional[Union[str, datapoints.ColorSpace]] = None,
) -> None:
super().__init__()
if isinstance(color_space, str):
color_space = datapoints.ColorSpace.from_str(color_space)
self.color_space = color_space
if isinstance(old_color_space, str):
old_color_space = datapoints.ColorSpace.from_str(old_color_space)
self.old_color_space = old_color_space
def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space)
class ClampBoundingBoxes(Transform):
_transformed_types = (datapoints.BoundingBox,)
......
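
With `ConvertColorSpace` removed, the most common conversion (RGB to grayscale) is still available through the stable API; a hypothetical drop-in for plain tensors:

```python
import torch
import torchvision.transforms.functional as TF

rgb = torch.rand(3, 16, 16)

# Roughly what ConvertColorSpace(color_space="GRAY", old_color_space="RGB")
# used to do for a plain RGB tensor:
gray = TF.rgb_to_grayscale(rgb, num_output_channels=1)
print(gray.shape)  # torch.Size([1, 16, 16])
```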
......@@ -7,10 +7,6 @@ from ._utils import is_simple_tensor # usort: skip
from ._meta import (
clamp_bounding_box,
convert_format_bounding_box,
convert_color_space_image_tensor,
convert_color_space_image_pil,
convert_color_space_video,
convert_color_space,
convert_dtype_image_tensor,
convert_dtype,
convert_dtype_video,
......
......@@ -27,13 +27,11 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if torch.jit.is_scripting() or is_simple_tensor(inpt):
old_color_space = datapoints._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
else:
old_color_space = None
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
inpt = inpt.as_subclass(torch.Tensor)
old_color_space = None # TODO: remove when un-deprecating
if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance(
inpt, (datapoints.Image, datapoints.Video)
):
inpt = inpt.as_subclass(torch.Tensor)
call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = (
......
from typing import List, Optional, Tuple, Union
from typing import List, Tuple, Union
import PIL.Image
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.datapoints import BoundingBoxFormat, ColorSpace
from torchvision.prototype.datapoints import BoundingBoxFormat
from torchvision.transforms import functional_pil as _FP
from torchvision.transforms.functional_tensor import _max_value
......@@ -225,29 +225,6 @@ def clamp_bounding_box(
return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)
def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
if not torch.all(alpha == _max_value(alpha.dtype)):
raise RuntimeError(
"Stripping the alpha channel if it contains values other than the max value is not supported."
)
return image
def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
if alpha is None:
shape = list(image.shape)
shape[-3] = 1
alpha = torch.full(shape, _max_value(image.dtype), dtype=image.dtype, device=image.device)
return torch.cat((image, alpha), dim=-3)
def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
repeats = [1] * grayscale.ndim
repeats[-3] = 3
return grayscale.repeat(repeats)
def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
r, g, b = image.unbind(dim=-3)
l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
......@@ -257,107 +234,6 @@ def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
return l_img
def convert_color_space_image_tensor(
image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
) -> torch.Tensor:
if new_color_space == old_color_space:
return image
if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
if old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.GRAY_ALPHA:
return _add_alpha(image)
elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB:
return _gray_to_rgb(image)
elif old_color_space == ColorSpace.GRAY and new_color_space == ColorSpace.RGB_ALPHA:
return _add_alpha(_gray_to_rgb(image))
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.GRAY:
return _strip_alpha(image)
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB:
return _gray_to_rgb(_strip_alpha(image))
elif old_color_space == ColorSpace.GRAY_ALPHA and new_color_space == ColorSpace.RGB_ALPHA:
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
return _add_alpha(_gray_to_rgb(image), alpha)
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(image)
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.GRAY_ALPHA:
return _add_alpha(_rgb_to_gray(image))
elif old_color_space == ColorSpace.RGB and new_color_space == ColorSpace.RGB_ALPHA:
return _add_alpha(image)
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY:
return _rgb_to_gray(_strip_alpha(image))
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.GRAY_ALPHA:
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
return _add_alpha(_rgb_to_gray(image), alpha)
elif old_color_space == ColorSpace.RGB_ALPHA and new_color_space == ColorSpace.RGB:
return _strip_alpha(image)
else:
raise RuntimeError(f"Conversion from {old_color_space} to {new_color_space} is not supported.")
_COLOR_SPACE_TO_PIL_MODE = {
ColorSpace.GRAY: "L",
ColorSpace.GRAY_ALPHA: "LA",
ColorSpace.RGB: "RGB",
ColorSpace.RGB_ALPHA: "RGBA",
}
@torch.jit.unused
def convert_color_space_image_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image:
old_mode = image.mode
try:
new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
except KeyError:
raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")
if image.mode == new_mode:
return image
return image.convert(new_mode)
def convert_color_space_video(
video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
) -> torch.Tensor:
return convert_color_space_image_tensor(video, old_color_space=old_color_space, new_color_space=new_color_space)
def convert_color_space(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT],
color_space: ColorSpace,
old_color_space: Optional[ColorSpace] = None,
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(convert_color_space)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if old_color_space is None:
raise RuntimeError(
"In order to convert the color space of simple tensors, "
"the `old_color_space=...` parameter needs to be passed."
)
return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space)
elif isinstance(inpt, datapoints.Image):
output = convert_color_space_image_tensor(
inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
)
return datapoints.Image.wrap_like(inpt, output, color_space=color_space)
elif isinstance(inpt, datapoints.Video):
output = convert_color_space_video(
inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
)
return datapoints.Video.wrap_like(inpt, output, color_space=color_space)
elif isinstance(inpt, PIL.Image.Image):
return convert_color_space_image_pil(inpt, color_space=color_space)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
......
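
For reference, the removed private helpers boil down to a handful of plain tensor operations. A minimal re-creation (assumed equivalents written for floating-point images, not the removed code verbatim):

```python
import torch

def gray_to_rgb(img: torch.Tensor) -> torch.Tensor:
    # Repeat the single channel three times along the channel dim (dim=-3).
    repeats = [1] * img.ndim
    repeats[-3] = 3
    return img.repeat(repeats)

def add_alpha(img: torch.Tensor) -> torch.Tensor:
    # Append a fully opaque alpha channel (1.0 for float images).
    shape = list(img.shape)
    shape[-3] = 1
    alpha = torch.full(shape, 1.0, dtype=img.dtype, device=img.device)
    return torch.cat((img, alpha), dim=-3)

def rgb_to_gray(img: torch.Tensor) -> torch.Tensor:
    # Weighted sum of the R, G, B channels, using the same coefficients
    # as the removed _rgb_to_gray helper.
    r, g, b = img.unbind(dim=-3)
    return r.mul(0.2989).add(g, alpha=0.587).add(b, alpha=0.114).unsqueeze(-3)

rgb = torch.rand(3, 8, 8)
print(add_alpha(rgb).shape, rgb_to_gray(rgb).shape)  # torch.Size([4, 8, 8]) torch.Size([1, 8, 8])
```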
......@@ -1234,6 +1234,9 @@ def affine(
return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
# Looks like to_grayscale() is a stand-alone functional that is never called
# from the transform classes. Perhaps it's still here for BC? I can't be
# bothered to dig. Anyway, this can be deprecated as we migrate to V2.
@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
"""Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
......