Unverified commit ca012d39, authored by Philip Meier and committed by GitHub

make PIL kernels private (#7831)

parent cdbbd666
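
In short, this commit renames the v2 conversion functionals and transforms, drops the "_tensor" suffix from the tensor kernels, and makes the PIL kernels private (underscore-prefixed). A minimal before/after sketch, assuming torchvision at this commit; the snippet is illustrative and not part of the diff itself:

from PIL import Image
import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F

pil_image = Image.new("RGB", (32, 32))

# functional conversions: to_image_tensor -> to_image, to_image_pil -> to_pil_image
image = F.to_image(pil_image)   # was F.to_image_tensor(pil_image)
back = F.to_pil_image(image)    # was F.to_image_pil(image)

# transform classes: ToImageTensor -> ToImage (ToImagePIL is dropped; ToPILImage stays)
pipeline = v2.Compose([v2.ToImage(), v2.ConvertImageDtype(torch.float32)])

# kernels: tensor kernels lose the suffix (resize_image_tensor -> resize_image),
# PIL kernels become private (resize_image_pil -> _resize_image_pil)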
@@ -228,12 +228,11 @@ Conversion
ToPILImage
v2.ToPILImage
-v2.ToImagePIL
ToTensor
v2.ToTensor
PILToTensor
v2.PILToTensor
-v2.ToImageTensor
+v2.ToImage
ConvertImageDtype
v2.ConvertImageDtype
v2.ToDtype
...
@@ -27,7 +27,7 @@ def show(sample):
image, target = sample
if isinstance(image, PIL.Image.Image):
-image = F.to_image_tensor(image)
+image = F.to_image(image)
image = F.to_dtype(image, torch.uint8, scale=True)
annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
@@ -101,7 +101,7 @@ transform = transforms.Compose(
transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(),
-transforms.ToImageTensor(),
+transforms.ToImage(),
transforms.ConvertImageDtype(torch.float32),
transforms.SanitizeBoundingBoxes(),
]
...
@@ -33,7 +33,7 @@ class DetectionPresetTrain:
transforms = []
backend = backend.lower()
if backend == "datapoint":
-transforms.append(T.ToImageTensor())
+transforms.append(T.ToImage())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
@@ -71,7 +71,7 @@ class DetectionPresetTrain:
if backend == "pil":
# Note: we could just convert to pure tensors even in v2.
-transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
transforms += [T.ConvertImageDtype(torch.float)]
@@ -94,11 +94,11 @@ class DetectionPresetEval:
backend = backend.lower()
if backend == "pil":
# Note: we could just convert to pure tensors even in v2?
-transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
elif backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "datapoint":
-transforms += [T.ToImageTensor()]
+transforms += [T.ToImage()]
else:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
...
@@ -32,7 +32,7 @@ class SegmentationPresetTrain:
transforms = []
backend = backend.lower()
if backend == "datapoint":
-transforms.append(T.ToImageTensor())
+transforms.append(T.ToImage())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
@@ -81,7 +81,7 @@ class SegmentationPresetEval:
if backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "datapoint":
-transforms += [T.ToImageTensor()]
+transforms += [T.ToImage()]
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
@@ -92,7 +92,7 @@ class SegmentationPresetEval:
if backend == "pil":
# Note: we could just convert to pure tensors even in v2?
-transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
transforms += [
T.ConvertImageDtype(torch.float),
...
@@ -27,7 +27,7 @@ from PIL import Image
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import datapoints, io
from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import to_dtype_image_tensor, to_image_pil, to_image_tensor
+from torchvision.transforms.v2.functional import to_dtype_image, to_image, to_pil_image
IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -293,7 +293,7 @@ class ImagePair(TensorLikePair):
**other_parameters,
):
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
-actual, expected = [to_image_tensor(input) for input in [actual, expected]]
+actual, expected = [to_image(input) for input in [actual, expected]]
super().__init__(actual, expected, **other_parameters)
self.mae = mae
@@ -536,7 +536,7 @@ def make_image_tensor(*args, **kwargs):
def make_image_pil(*args, **kwargs):
-return to_image_pil(make_image(*args, **kwargs))
+return to_pil_image(make_image(*args, **kwargs))
def make_image_loader(
@@ -609,12 +609,12 @@ def make_image_loader_for_interpolation(
)
)
-image_tensor = to_image_tensor(image_pil)
+image_tensor = to_image(image_pil)
if memory_format == torch.contiguous_format:
image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
else:
image_tensor = image_tensor.to(device=device)
-image_tensor = to_dtype_image_tensor(image_tensor, dtype=dtype, scale=True)
+image_tensor = to_dtype_image(image_tensor, dtype=dtype, scale=True)
return datapoints.Image(image_tensor)
...
@@ -17,7 +17,7 @@ from prototype_common_utils import make_label
from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from torchvision.prototype import datapoints, transforms
-from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
+from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
from torchvision.transforms.v2.utils import check_type, is_simple_tensor
BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -387,7 +387,7 @@ def test_fixed_sized_crop_against_detection_reference():
size = (600, 800)
num_objects = 22
-pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
+pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
...
@@ -666,19 +666,19 @@ class TestTransform:
t(inpt)
-class TestToImageTensor:
+class TestToImage:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch(
-"torchvision.transforms.v2.functional.to_image_tensor",
+"torchvision.transforms.v2.functional.to_image",
return_value=torch.rand(1, 3, 8, 8),
)
inpt = mocker.MagicMock(spec=inpt_type)
-transform = transforms.ToImageTensor()
+transform = transforms.ToImage()
transform(inpt)
if inpt_type in (datapoints.BoundingBoxes, datapoints.Image, str, int):
assert fn.call_count == 0
@@ -686,30 +686,13 @@ class TestToImageTensor:
fn.assert_called_once_with(inpt)
-class TestToImagePIL:
-@pytest.mark.parametrize(
-"inpt_type",
-[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
-)
-def test__transform(self, inpt_type, mocker):
-fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")
-inpt = mocker.MagicMock(spec=inpt_type)
-transform = transforms.ToImagePIL()
-transform(inpt)
-if inpt_type in (datapoints.BoundingBoxes, PIL.Image.Image, str, int):
-assert fn.call_count == 0
-else:
-fn.assert_called_once_with(inpt, mode=transform.mode)
class TestToPILImage:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
-fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")
+fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image")
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToPILImage()
@@ -1013,7 +996,7 @@ def test_antialias_warning():
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
-@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
@@ -1074,7 +1057,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
-@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
@pytest.mark.parametrize("sanitize", (True, False))
def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
torch.manual_seed(0)
...
@@ -30,7 +30,7 @@ from torchvision._utils import sequence_to_str
from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2._utils import _get_fill
-from torchvision.transforms.v2.functional import to_image_pil
+from torchvision.transforms.v2.functional import to_pil_image
from torchvision.transforms.v2.utils import query_size
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
@@ -630,7 +630,7 @@ def check_call_consistency(
)
if image.ndim == 3 and supports_pil:
-image_pil = to_image_pil(image)
+image_pil = to_pil_image(image)
try:
torch.manual_seed(0)
@@ -869,7 +869,7 @@ class TestToTensorTransforms:
legacy_transform = legacy_transforms.PILToTensor()
for image in make_images(extra_dims=[()]):
-image_pil = to_image_pil(image)
+image_pil = to_pil_image(image)
assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
@@ -879,7 +879,7 @@ class TestToTensorTransforms:
legacy_transform = legacy_transforms.ToTensor()
for image in make_images(extra_dims=[()]):
-image_pil = to_image_pil(image)
+image_pil = to_pil_image(image)
image_numpy = np.array(image_pil)
assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
@@ -1088,7 +1088,7 @@ class TestRefDetTransforms:
def make_label(extra_dims, categories):
return torch.randint(categories, extra_dims, dtype=torch.int64)
-pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
+pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
target = {
"boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -1192,7 +1192,7 @@ class TestRefSegTransforms:
conv_fns = []
if supports_pil:
-conv_fns.append(to_image_pil)
+conv_fns.append(to_pil_image)
conv_fns.extend([torch.Tensor, lambda x: x])
for conv_fn in conv_fns:
@@ -1201,8 +1201,8 @@ class TestRefSegTransforms:
dp = (conv_fn(datapoint_image), datapoint_mask)
dp_ref = (
-to_image_pil(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
-to_image_pil(datapoint_mask),
+to_pil_image(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
+to_pil_image(datapoint_mask),
)
yield dp, dp_ref
...
@@ -280,12 +280,12 @@ class TestKernels:
adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)
actual = info.kernel(
-F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True),
+F.to_dtype_image(input, dtype=torch.float32, scale=True),
*adapted_other_args,
**adapted_kwargs,
)
-expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)
+expected = F.to_dtype_image(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)
assert_close(
actual,
@@ -377,7 +377,7 @@ class TestDispatchers:
if image_datapoint.ndim > 3:
pytest.skip("Input is batched")
-image_pil = F.to_image_pil(image_datapoint)
+image_pil = F.to_pil_image(image_datapoint)
output = info.dispatcher(image_pil, *other_args, **kwargs)
@@ -470,7 +470,7 @@ class TestDispatchers:
(F.hflip, F.horizontal_flip),
(F.vflip, F.vertical_flip),
(F.get_image_num_channels, F.get_num_channels),
-(F.to_pil_image, F.to_image_pil),
+(F.to_pil_image, F.to_pil_image),
(F.elastic_transform, F.elastic),
(F.to_grayscale, F.rgb_to_grayscale),
]
@@ -493,7 +493,7 @@ def test_normalize_image_tensor_stats(device, num_channels):
mean = image.mean(dim=(1, 2)).tolist()
std = image.std(dim=(1, 2)).tolist()
-assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))
+assert_samples_from_standard_normal(F.normalize_image(image, mean, std))
class TestClampBoundingBoxes:
@@ -899,7 +899,7 @@ def test_correctness_center_crop_mask(device, output_size):
_, image_height, image_width = mask.shape
if crop_width > image_height or crop_height > image_width:
padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-mask = F.pad_image_tensor(mask, padding, fill=0)
+mask = F.pad_image(mask, padding, fill=0)
left = round((image_width - crop_width) * 0.5)
top = round((image_height - crop_height) * 0.5)
@@ -920,7 +920,7 @@ def test_correctness_center_crop_mask(device, output_size):
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
-fn = F.gaussian_blur_image_tensor
+fn = F.gaussian_blur_image
# true_cv2_results = {
# # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
@@ -977,8 +977,8 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize,
PIL.Image.new("RGB", (32, 32), 122),
],
)
-def test_to_image_tensor(inpt):
-output = F.to_image_tensor(inpt)
+def test_to_image(inpt):
+output = F.to_image(inpt)
assert isinstance(output, torch.Tensor)
assert output.shape == (3, 32, 32)
@@ -993,8 +993,8 @@ def test_to_image_tensor(inpt):
],
)
@pytest.mark.parametrize("mode", [None, "RGB"])
-def test_to_image_pil(inpt, mode):
-output = F.to_image_pil(inpt, mode=mode)
+def test_to_pil_image(inpt, mode):
+output = F.to_pil_image(inpt, mode=mode)
assert isinstance(output, PIL.Image.Image)
assert np.asarray(inpt).sum() == np.asarray(output).sum()
@@ -1002,12 +1002,12 @@ def test_to_image_pil(inpt, mode):
def test_equalize_image_tensor_edge_cases():
inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
-output = F.equalize_image_tensor(inpt)
+output = F.equalize_image(inpt)
torch.testing.assert_close(inpt, output)
inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
inpt[..., 100:, 100:] = 1
-output = F.equalize_image_tensor(inpt)
+output = F.equalize_image(inpt)
assert output.unique().tolist() == [0, 255]
@@ -1024,7 +1024,7 @@ def test_correctness_uniform_temporal_subsample(device):
# TODO: We can remove this test and related torchvision workaround
# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
@make_info_args_kwargs_parametrization(
-[info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
+[info for info in KERNEL_INFOS if info.kernel is F.resize_image],
args_kwargs_fn=lambda info: info.reference_inputs_fn(),
)
def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
...
@@ -437,7 +437,7 @@ class TestResize:
check_cuda_vs_cpu_tolerances = dict(rtol=0, atol=atol / 255 if dtype.is_floating_point else atol)
check_kernel(
-F.resize_image_tensor,
+F.resize_image,
make_image(self.INPUT_SIZE, dtype=dtype, device=device),
size=size,
interpolation=interpolation,
@@ -495,9 +495,9 @@ class TestResize:
@pytest.mark.parametrize(
("kernel", "input_type"),
[
-(F.resize_image_tensor, torch.Tensor),
-(F.resize_image_pil, PIL.Image.Image),
-(F.resize_image_tensor, datapoints.Image),
+(F.resize_image, torch.Tensor),
+(F._resize_image_pil, PIL.Image.Image),
+(F.resize_image, datapoints.Image),
(F.resize_bounding_boxes, datapoints.BoundingBoxes),
(F.resize_mask, datapoints.Mask),
(F.resize_video, datapoints.Video),
@@ -541,9 +541,7 @@ class TestResize:
image = make_image(self.INPUT_SIZE, dtype=torch.uint8)
actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=True)
-expected = F.to_image_tensor(
-F.resize(F.to_image_pil(image), size=size, interpolation=interpolation, **max_size_kwarg)
-)
+expected = F.to_image(F.resize(F.to_pil_image(image), size=size, interpolation=interpolation, **max_size_kwarg))
self._check_output_size(image, actual, size=size, **max_size_kwarg)
torch.testing.assert_close(actual, expected, atol=1, rtol=0)
@@ -739,7 +737,7 @@ class TestHorizontalFlip:
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor(self, dtype, device):
-check_kernel(F.horizontal_flip_image_tensor, make_image(dtype=dtype, device=device))
+check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device))
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@@ -770,9 +768,9 @@ class TestHorizontalFlip:
@pytest.mark.parametrize(
("kernel", "input_type"),
[
-(F.horizontal_flip_image_tensor, torch.Tensor),
-(F.horizontal_flip_image_pil, PIL.Image.Image),
-(F.horizontal_flip_image_tensor, datapoints.Image),
+(F.horizontal_flip_image, torch.Tensor),
+(F._horizontal_flip_image_pil, PIL.Image.Image),
+(F.horizontal_flip_image, datapoints.Image),
(F.horizontal_flip_bounding_boxes, datapoints.BoundingBoxes),
(F.horizontal_flip_mask, datapoints.Mask),
(F.horizontal_flip_video, datapoints.Video),
@@ -796,7 +794,7 @@ class TestHorizontalFlip:
image = make_image(dtype=torch.uint8, device="cpu")
actual = fn(image)
-expected = F.to_image_tensor(F.horizontal_flip(F.to_image_pil(image)))
+expected = F.to_image(F.horizontal_flip(F.to_pil_image(image)))
torch.testing.assert_close(actual, expected)
@@ -900,7 +898,7 @@ class TestAffine:
if param == "fill":
value = adapt_fill(value, dtype=dtype)
self._check_kernel(
-F.affine_image_tensor,
+F.affine_image,
make_image(dtype=dtype, device=device),
**{param: value},
check_scripted_vs_eager=not (param in {"shear", "fill"} and isinstance(value, (int, float))),
@@ -946,9 +944,9 @@ class TestAffine:
@pytest.mark.parametrize(
("kernel", "input_type"),
[
-(F.affine_image_tensor, torch.Tensor),
-(F.affine_image_pil, PIL.Image.Image),
-(F.affine_image_tensor, datapoints.Image),
+(F.affine_image, torch.Tensor),
+(F._affine_image_pil, PIL.Image.Image),
+(F.affine_image, datapoints.Image),
(F.affine_bounding_boxes, datapoints.BoundingBoxes),
(F.affine_mask, datapoints.Mask),
(F.affine_video, datapoints.Video),
@@ -991,9 +989,9 @@ class TestAffine:
interpolation=interpolation,
fill=fill,
)
-expected = F.to_image_tensor(
+expected = F.to_image(
F.affine(
-F.to_image_pil(image),
+F.to_pil_image(image),
angle=angle,
translate=translate,
scale=scale,
@@ -1026,7 +1024,7 @@ class TestAffine:
actual = transform(image)
torch.manual_seed(seed)
-expected = F.to_image_tensor(transform(F.to_image_pil(image)))
+expected = F.to_image(transform(F.to_pil_image(image)))
mae = (actual.float() - expected.float()).abs().mean()
assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
@@ -1204,7 +1202,7 @@ class TestVerticalFlip:
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor(self, dtype, device):
-check_kernel(F.vertical_flip_image_tensor, make_image(dtype=dtype, device=device))
+check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device))
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@@ -1235,9 +1233,9 @@ class TestVerticalFlip:
@pytest.mark.parametrize(
("kernel", "input_type"),
[
-(F.vertical_flip_image_tensor, torch.Tensor),
-(F.vertical_flip_image_pil, PIL.Image.Image),
-(F.vertical_flip_image_tensor, datapoints.Image),
+(F.vertical_flip_image, torch.Tensor),
+(F._vertical_flip_image_pil, PIL.Image.Image),
+(F.vertical_flip_image, datapoints.Image),
(F.vertical_flip_bounding_boxes, datapoints.BoundingBoxes),
(F.vertical_flip_mask, datapoints.Mask),
(F.vertical_flip_video, datapoints.Video),
@@ -1259,7 +1257,7 @@ class TestVerticalFlip:
image = make_image(dtype=torch.uint8, device="cpu")
actual = fn(image)
-expected = F.to_image_tensor(F.vertical_flip(F.to_image_pil(image)))
+expected = F.to_image(F.vertical_flip(F.to_pil_image(image)))
torch.testing.assert_close(actual, expected)
@@ -1339,7 +1337,7 @@ class TestRotate:
if param != "angle":
kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"]
check_kernel(
-F.rotate_image_tensor,
+F.rotate_image,
make_image(dtype=dtype, device=device),
**kwargs,
check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
@@ -1385,9 +1383,9 @@ class TestRotate:
@pytest.mark.parametrize(
("kernel", "input_type"),
[
-(F.rotate_image_tensor, torch.Tensor),
-(F.rotate_image_pil, PIL.Image.Image),
-(F.rotate_image_tensor, datapoints.Image),
+(F.rotate_image, torch.Tensor),
+(F._rotate_image_pil, PIL.Image.Image),
+(F.rotate_image, datapoints.Image),
(F.rotate_bounding_boxes, datapoints.BoundingBoxes),
(F.rotate_mask, datapoints.Mask),
(F.rotate_video, datapoints.Video),
@@ -1419,9 +1417,9 @@ class TestRotate:
fill = adapt_fill(fill, dtype=torch.uint8)
actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill)
-expected = F.to_image_tensor(
+expected = F.to_image(
F.rotate(
-F.to_image_pil(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill
+F.to_pil_image(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill
)
)
@@ -1452,7 +1450,7 @@ class TestRotate:
actual = transform(image)
torch.manual_seed(seed)
-expected = F.to_image_tensor(transform(F.to_image_pil(image)))
+expected = F.to_image(transform(F.to_pil_image(image)))
mae = (actual.float() - expected.float()).abs().mean()
assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6
@@ -1621,8 +1619,8 @@ class TestToDtype:
@pytest.mark.parametrize(
("kernel", "make_input"),
[
-(F.to_dtype_image_tensor, make_image_tensor),
-(F.to_dtype_image_tensor, make_image),
+(F.to_dtype_image, make_image_tensor),
+(F.to_dtype_image, make_image),
(F.to_dtype_video, make_video),
],
)
@@ -1801,7 +1799,7 @@ class TestAdjustBrightness:
@pytest.mark.parametrize(
("kernel", "make_input"),
[
-(F.adjust_brightness_image_tensor, make_image),
+(F.adjust_brightness_image, make_image),
(F.adjust_brightness_video, make_video),
],
)
@@ -1817,9 +1815,9 @@ class TestAdjustBrightness:
@pytest.mark.parametrize(
("kernel", "input_type"),
[
-(F.adjust_brightness_image_tensor, torch.Tensor),
-(F.adjust_brightness_image_pil, PIL.Image.Image),
-(F.adjust_brightness_image_tensor, datapoints.Image),
+(F.adjust_brightness_image, torch.Tensor),
+(F._adjust_brightness_image_pil, PIL.Image.Image),
+(F.adjust_brightness_image, datapoints.Image),
(F.adjust_brightness_video, datapoints.Video),
],
)
@@ -1831,7 +1829,7 @@ class TestAdjustBrightness:
image = make_image(dtype=torch.uint8, device="cpu")
actual = F.adjust_brightness(image, brightness_factor=brightness_factor)
-expected = F.to_image_tensor(F.adjust_brightness(F.to_image_pil(image), brightness_factor=brightness_factor))
+expected = F.to_image(F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor))
torch.testing.assert_close(actual, expected)
@@ -1979,9 +1977,9 @@ class TestShapeGetters:
@pytest.mark.parametrize(
("kernel", "make_input"),
[
-(F.get_dimensions_image_tensor, make_image_tensor),
-(F.get_dimensions_image_pil, make_image_pil),
-(F.get_dimensions_image_tensor, make_image),
+(F.get_dimensions_image, make_image_tensor),
+(F._get_dimensions_image_pil, make_image_pil),
+(F.get_dimensions_image, make_image),
(F.get_dimensions_video, make_video),
],
)
@@ -1996,9 +1994,9 @@ class TestShapeGetters:
@pytest.mark.parametrize(
("kernel", "make_input"),
[
-(F.get_num_channels_image_tensor, make_image_tensor),
-(F.get_num_channels_image_pil, make_image_pil),
-(F.get_num_channels_image_tensor, make_image),
+(F.get_num_channels_image, make_image_tensor),
+(F._get_num_channels_image_pil, make_image_pil),
+(F.get_num_channels_image, make_image),
(F.get_num_channels_video, make_video),
],
)
@@ -2012,9 +2010,9 @@ class TestShapeGetters:
@pytest.mark.parametrize(
("kernel", "make_input"),
[
-(F.get_size_image_tensor, make_image_tensor),
-(F.get_size_image_pil, make_image_pil),
-(F.get_size_image_tensor, make_image),
+(F.get_size_image, make_image_tensor),
+(F._get_size_image_pil, make_image_pil),
+(F.get_size_image, make_image),
(F.get_size_bounding_boxes, make_bounding_box),
(F.get_size_mask, make_detection_mask),
(F.get_size_mask, make_segmentation_mask),
@@ -2101,7 +2099,7 @@ class TestRegisterKernel:
F.register_kernel(F.resize, object)
with pytest.raises(ValueError, match="cannot be registered for the builtin datapoint classes"):
-F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)
+F.register_kernel(F.resize, datapoints.Image)(F.resize_image)
class CustomDatapoint(datapoints.Datapoint):
pass
@@ -2119,9 +2117,9 @@ class TestGetKernel:
# We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination
# would also be fine
KERNELS = {
-torch.Tensor: F.resize_image_tensor,
-PIL.Image.Image: F.resize_image_pil,
-datapoints.Image: F.resize_image_tensor,
+torch.Tensor: F.resize_image,
+PIL.Image.Image: F._resize_image_pil,
+datapoints.Image: F.resize_image,
datapoints.BoundingBoxes: F.resize_bounding_boxes,
datapoints.Mask: F.resize_mask,
datapoints.Video: F.resize_video,
@@ -2217,10 +2215,10 @@ class TestPermuteChannels:
@pytest.mark.parametrize(
("kernel", "make_input"),
[
-(F.permute_channels_image_tensor, make_image_tensor),
+(F.permute_channels_image, make_image_tensor),
# FIXME
# check_kernel does not support PIL kernel, but it should
-(F.permute_channels_image_tensor, make_image),
+(F.permute_channels_image, make_image),
(F.permute_channels_video, make_video),
],
)
@@ -2236,9 +2234,9 @@ class TestPermuteChannels:
@pytest.mark.parametrize(
("kernel", "input_type"),
[
-(F.permute_channels_image_tensor, torch.Tensor),
-(F.permute_channels_image_pil, PIL.Image.Image),
-(F.permute_channels_image_tensor, datapoints.Image),
+(F.permute_channels_image, torch.Tensor),
+(F._permute_channels_image_pil, PIL.Image.Image),
+(F.permute_channels_image, datapoints.Image),
(F.permute_channels_video, datapoints.Video),
],
)
...
@@ -7,7 +7,7 @@ import torchvision.transforms.v2.utils
from common_utils import DEFAULT_SIZE, make_bounding_box, make_detection_mask, make_image
from torchvision import datapoints
-from torchvision.transforms.v2.functional import to_image_pil
+from torchvision.transforms.v2.functional import to_pil_image
from torchvision.transforms.v2.utils import has_all, has_any
@@ -44,7 +44,7 @@ MASK = make_detection_mask(DEFAULT_SIZE)
True,
),
(
-(to_image_pil(IMAGE),),
+(to_pil_image(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
True,
),
...
@@ -142,32 +142,32 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.crop,
kernels={
-datapoints.Image: F.crop_image_tensor,
+datapoints.Image: F.crop_image,
datapoints.Video: F.crop_video,
datapoints.BoundingBoxes: F.crop_bounding_boxes,
datapoints.Mask: F.crop_mask,
},
-pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
+pil_kernel_info=PILKernelInfo(F._crop_image_pil, kernel_name="crop_image_pil"),
),
DispatcherInfo(
F.resized_crop,
kernels={
-datapoints.Image: F.resized_crop_image_tensor,
+datapoints.Image: F.resized_crop_image,
datapoints.Video: F.resized_crop_video,
datapoints.BoundingBoxes: F.resized_crop_bounding_boxes,
datapoints.Mask: F.resized_crop_mask,
},
-pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil),
+pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil),
),
DispatcherInfo(
F.pad,
kernels={
-datapoints.Image: F.pad_image_tensor,
+datapoints.Image: F.pad_image,
datapoints.Video: F.pad_video,
datapoints.BoundingBoxes: F.pad_bounding_boxes,
datapoints.Mask: F.pad_mask,
},
-pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
+pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
test_marks=[
*xfails_pil(
reason=(
@@ -184,12 +184,12 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.perspective,
kernels={
-datapoints.Image: F.perspective_image_tensor,
+datapoints.Image: F.perspective_image,
datapoints.Video: F.perspective_video,
datapoints.BoundingBoxes: F.perspective_bounding_boxes,
datapoints.Mask: F.perspective_mask,
},
-pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
+pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
test_marks=[
*xfails_pil_if_fill_sequence_needs_broadcast,
xfail_jit_python_scalar_arg("fill"),
@@ -198,23 +198,23 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.elastic,
kernels={
-datapoints.Image: F.elastic_image_tensor,
+datapoints.Image: F.elastic_image,
datapoints.Video: F.elastic_video,
datapoints.BoundingBoxes: F.elastic_bounding_boxes,
datapoints.Mask: F.elastic_mask,
},
-pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
+pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
test_marks=[xfail_jit_python_scalar_arg("fill")],
),
DispatcherInfo(
F.center_crop,
kernels={
-datapoints.Image: F.center_crop_image_tensor,
+datapoints.Image: F.center_crop_image,
datapoints.Video: F.center_crop_video,
datapoints.BoundingBoxes: F.center_crop_bounding_boxes,
datapoints.Mask: F.center_crop_mask,
},
-pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
+pil_kernel_info=PILKernelInfo(F._center_crop_image_pil),
test_marks=[
xfail_jit_python_scalar_arg("output_size"),
],
@@ -222,10 +222,10 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.gaussian_blur,
kernels={
-datapoints.Image: F.gaussian_blur_image_tensor,
+datapoints.Image: F.gaussian_blur_image,
datapoints.Video: F.gaussian_blur_video,
},
-pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
+pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil),
test_marks=[
xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"),
@@ -234,58 +234,58 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.equalize,
kernels={
-datapoints.Image: F.equalize_image_tensor,
+datapoints.Image: F.equalize_image,
datapoints.Video: F.equalize_video,
},
-pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"),
+pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
),
DispatcherInfo(
F.invert,
kernels={
-datapoints.Image: F.invert_image_tensor,
+datapoints.Image: F.invert_image,
datapoints.Video: F.invert_video,
},
-pil_kernel_info=PILKernelInfo(F.invert_image_pil, kernel_name="invert_image_pil"),
+pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
),
DispatcherInfo(
F.posterize,
kernels={
-datapoints.Image: F.posterize_image_tensor,
+datapoints.Image: F.posterize_image,
datapoints.Video: F.posterize_video,
},
-pil_kernel_info=PILKernelInfo(F.posterize_image_pil, kernel_name="posterize_image_pil"),
+pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
),
DispatcherInfo(
F.solarize,
kernels={
-datapoints.Image: F.solarize_image_tensor,
+datapoints.Image: F.solarize_image,
datapoints.Video: F.solarize_video,
},
-pil_kernel_info=PILKernelInfo(F.solarize_image_pil, kernel_name="solarize_image_pil"),
+pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
),
DispatcherInfo(
F.autocontrast,
kernels={
-datapoints.Image: F.autocontrast_image_tensor,
+datapoints.Image: F.autocontrast_image,
datapoints.Video: F.autocontrast_video,
},
-pil_kernel_info=PILKernelInfo(F.autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
+pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
),
DispatcherInfo(
F.adjust_sharpness,
kernels={
-datapoints.Image: F.adjust_sharpness_image_tensor,
+datapoints.Image: F.adjust_sharpness_image,
datapoints.Video: F.adjust_sharpness_video,
},
-pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
+pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
),
DispatcherInfo(
F.erase,
kernels={
-datapoints.Image: F.erase_image_tensor,
+datapoints.Image: F.erase_image,
datapoints.Video: F.erase_video,
},
-pil_kernel_info=PILKernelInfo(F.erase_image_pil),
+pil_kernel_info=PILKernelInfo(F._erase_image_pil),
test_marks=[
skip_dispatch_datapoint,
],
@@ -293,42 +293,42 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.adjust_contrast,
kernels={
-datapoints.Image: F.adjust_contrast_image_tensor,
+datapoints.Image: F.adjust_contrast_image,
datapoints.Video: F.adjust_contrast_video,
},
-pil_kernel_info=PILKernelInfo(F.adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
+pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
),
DispatcherInfo(
F.adjust_gamma,
kernels={
-datapoints.Image: F.adjust_gamma_image_tensor,
+datapoints.Image: F.adjust_gamma_image,
datapoints.Video: F.adjust_gamma_video,
},
-pil_kernel_info=PILKernelInfo(F.adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
+pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
),
DispatcherInfo(
F.adjust_hue,
kernels={
-datapoints.Image: F.adjust_hue_image_tensor,
+datapoints.Image: F.adjust_hue_image,
datapoints.Video: F.adjust_hue_video,
},
-pil_kernel_info=PILKernelInfo(F.adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
+pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
),
DispatcherInfo(
F.adjust_saturation,
kernels={
-datapoints.Image: F.adjust_saturation_image_tensor,
+datapoints.Image: F.adjust_saturation_image,
datapoints.Video: F.adjust_saturation_video,
},
-pil_kernel_info=PILKernelInfo(F.adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
+pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
),
DispatcherInfo(
F.five_crop,
kernels={
-datapoints.Image: F.five_crop_image_tensor,
+datapoints.Image: F.five_crop_image,
datapoints.Video: F.five_crop_video,
},
-pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
+pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
test_marks=[
xfail_jit_python_scalar_arg("size"),
*multi_crop_skips,
@@ -337,19 +337,19 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.ten_crop,
kernels={
-datapoints.Image: F.ten_crop_image_tensor,
+datapoints.Image: F.ten_crop_image,
datapoints.Video: F.ten_crop_video,
},
test_marks=[
xfail_jit_python_scalar_arg("size"),
*multi_crop_skips,
],
-pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
+pil_kernel_info=PILKernelInfo(F._ten_crop_image_pil),
),
DispatcherInfo(
F.normalize,
kernels={
-datapoints.Image: F.normalize_image_tensor,
+datapoints.Image: F.normalize_image,
datapoints.Video: F.normalize_video,
},
test_marks=[
...
@@ -122,12 +122,12 @@ def pil_reference_wrapper(pil_kernel):
f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
)
-input_pil = F.to_image_pil(input_tensor)
+input_pil = F.to_pil_image(input_tensor)
output_pil = pil_kernel(input_pil, *other_args, **kwargs)
if not isinstance(output_pil, PIL.Image.Image):
return output_pil
-output_tensor = F.to_image_tensor(output_pil)
+output_tensor = F.to_image(output_pil)
# 2D mask shenanigans
if output_tensor.ndim == 2 and input_tensor.ndim == 3:
@@ -331,10 +331,10 @@ def reference_inputs_crop_bounding_boxes():
KERNEL_INFOS.extend(
[
KernelInfo(
-F.crop_image_tensor,
+F.crop_image,
kernel_name="crop_image_tensor",
sample_inputs_fn=sample_inputs_crop_image_tensor,
-reference_fn=pil_reference_wrapper(F.crop_image_pil),
+reference_fn=pil_reference_wrapper(F._crop_image_pil),
reference_inputs_fn=reference_inputs_crop_image_tensor,
float32_vs_uint8=True,
),
@@ -347,7 +347,7 @@ KERNEL_INFOS.extend(
KernelInfo(
F.crop_mask,
sample_inputs_fn=sample_inputs_crop_mask,
-reference_fn=pil_reference_wrapper(F.crop_image_pil),
+reference_fn=pil_reference_wrapper(F._crop_image_pil),
reference_inputs_fn=reference_inputs_crop_mask,
float32_vs_uint8=True,
),
@@ -373,7 +373,7 @@ def reference_resized_crop_image_tensor(*args, **kwargs):
F.InterpolationMode.BICUBIC,
}:
raise pytest.UsageError("Anti-aliasing is always active in PIL")
-return F.resized_crop_image_pil(*args, **kwargs)
+return F._resized_crop_image_pil(*args, **kwargs)
def reference_inputs_resized_crop_image_tensor():
@@ -417,7 +417,7 @@ def sample_inputs_resized_crop_video():
KERNEL_INFOS.extend(
[
KernelInfo(
-F.resized_crop_image_tensor,
+F.resized_crop_image,
sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
reference_fn=reference_resized_crop_image_tensor,
reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
@@ -570,9 +570,9 @@ def pad_xfail_jit_fill_condition(args_kwargs):
KERNEL_INFOS.extend(
[
KernelInfo(
-F.pad_image_tensor,
+F.pad_image,
sample_inputs_fn=sample_inputs_pad_image_tensor,
-reference_fn=pil_reference_wrapper(F.pad_image_pil),
+reference_fn=pil_reference_wrapper(F._pad_image_pil),
reference_inputs_fn=reference_inputs_pad_image_tensor,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
closeness_kwargs=float32_vs_uint8_pixel_difference(),
@@ -595,7 +595,7 @@ KERNEL_INFOS.extend(
KernelInfo(
F.pad_mask,
sample_inputs_fn=sample_inputs_pad_mask,
-reference_fn=pil_reference_wrapper(F.pad_image_pil),
+reference_fn=pil_reference_wrapper(F._pad_image_pil),
reference_inputs_fn=reference_inputs_pad_mask,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
),
@@ -690,9 +690,9 @@ def sample_inputs_perspective_video():
KERNEL_INFOS.extend(
[
KernelInfo(
-F.perspective_image_tensor,
+F.perspective_image,
sample_inputs_fn=sample_inputs_perspective_image_tensor,
-reference_fn=pil_reference_wrapper(F.perspective_image_pil),
+reference_fn=pil_reference_wrapper(F._perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_image_tensor,
float32_vs_uint8=float32_vs_uint8_fill_adapter,
closeness_kwargs={
@@ -715,7 +715,7 @@ KERNEL_INFOS.extend(
KernelInfo(
F.perspective_mask,
sample_inputs_fn=sample_inputs_perspective_mask,
-reference_fn=pil_reference_wrapper(F.perspective_image_pil),
+reference_fn=pil_reference_wrapper(F._perspective_image_pil),
reference_inputs_fn=reference_inputs_perspective_mask,
float32_vs_uint8=True,
closeness_kwargs={
@@ -786,7 +786,7 @@ def sample_inputs_elastic_video():
KERNEL_INFOS.extend(
[
KernelInfo(
-F.elastic_image_tensor,
+F.elastic_image,
sample_inputs_fn=sample_inputs_elastic_image_tensor,
reference_inputs_fn=reference_inputs_elastic_image_tensor, reference_inputs_fn=reference_inputs_elastic_image_tensor,
float32_vs_uint8=float32_vs_uint8_fill_adapter, float32_vs_uint8=float32_vs_uint8_fill_adapter,
...@@ -870,9 +870,9 @@ def sample_inputs_center_crop_video(): ...@@ -870,9 +870,9 @@ def sample_inputs_center_crop_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.center_crop_image_tensor, F.center_crop_image,
sample_inputs_fn=sample_inputs_center_crop_image_tensor, sample_inputs_fn=sample_inputs_center_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil), reference_fn=pil_reference_wrapper(F._center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_image_tensor, reference_inputs_fn=reference_inputs_center_crop_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
test_marks=[ test_marks=[
...@@ -889,7 +889,7 @@ KERNEL_INFOS.extend( ...@@ -889,7 +889,7 @@ KERNEL_INFOS.extend(
KernelInfo( KernelInfo(
F.center_crop_mask, F.center_crop_mask,
sample_inputs_fn=sample_inputs_center_crop_mask, sample_inputs_fn=sample_inputs_center_crop_mask,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil), reference_fn=pil_reference_wrapper(F._center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_mask, reference_inputs_fn=reference_inputs_center_crop_mask,
float32_vs_uint8=True, float32_vs_uint8=True,
test_marks=[ test_marks=[
...@@ -924,7 +924,7 @@ def sample_inputs_gaussian_blur_video(): ...@@ -924,7 +924,7 @@ def sample_inputs_gaussian_blur_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.gaussian_blur_image_tensor, F.gaussian_blur_image,
sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor, sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
closeness_kwargs=cuda_vs_cpu_pixel_difference(), closeness_kwargs=cuda_vs_cpu_pixel_difference(),
test_marks=[ test_marks=[
...@@ -1010,10 +1010,10 @@ def sample_inputs_equalize_video(): ...@@ -1010,10 +1010,10 @@ def sample_inputs_equalize_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.equalize_image_tensor, F.equalize_image,
kernel_name="equalize_image_tensor", kernel_name="equalize_image_tensor",
sample_inputs_fn=sample_inputs_equalize_image_tensor, sample_inputs_fn=sample_inputs_equalize_image_tensor,
reference_fn=pil_reference_wrapper(F.equalize_image_pil), reference_fn=pil_reference_wrapper(F._equalize_image_pil),
float32_vs_uint8=True, float32_vs_uint8=True,
reference_inputs_fn=reference_inputs_equalize_image_tensor, reference_inputs_fn=reference_inputs_equalize_image_tensor,
), ),
...@@ -1043,10 +1043,10 @@ def sample_inputs_invert_video(): ...@@ -1043,10 +1043,10 @@ def sample_inputs_invert_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.invert_image_tensor, F.invert_image,
kernel_name="invert_image_tensor", kernel_name="invert_image_tensor",
sample_inputs_fn=sample_inputs_invert_image_tensor, sample_inputs_fn=sample_inputs_invert_image_tensor,
reference_fn=pil_reference_wrapper(F.invert_image_pil), reference_fn=pil_reference_wrapper(F._invert_image_pil),
reference_inputs_fn=reference_inputs_invert_image_tensor, reference_inputs_fn=reference_inputs_invert_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
), ),
...@@ -1082,10 +1082,10 @@ def sample_inputs_posterize_video(): ...@@ -1082,10 +1082,10 @@ def sample_inputs_posterize_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.posterize_image_tensor, F.posterize_image,
kernel_name="posterize_image_tensor", kernel_name="posterize_image_tensor",
sample_inputs_fn=sample_inputs_posterize_image_tensor, sample_inputs_fn=sample_inputs_posterize_image_tensor,
reference_fn=pil_reference_wrapper(F.posterize_image_pil), reference_fn=pil_reference_wrapper(F._posterize_image_pil),
reference_inputs_fn=reference_inputs_posterize_image_tensor, reference_inputs_fn=reference_inputs_posterize_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs=float32_vs_uint8_pixel_difference(), closeness_kwargs=float32_vs_uint8_pixel_difference(),
...@@ -1127,10 +1127,10 @@ def sample_inputs_solarize_video(): ...@@ -1127,10 +1127,10 @@ def sample_inputs_solarize_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.solarize_image_tensor, F.solarize_image,
kernel_name="solarize_image_tensor", kernel_name="solarize_image_tensor",
sample_inputs_fn=sample_inputs_solarize_image_tensor, sample_inputs_fn=sample_inputs_solarize_image_tensor,
reference_fn=pil_reference_wrapper(F.solarize_image_pil), reference_fn=pil_reference_wrapper(F._solarize_image_pil),
reference_inputs_fn=reference_inputs_solarize_image_tensor, reference_inputs_fn=reference_inputs_solarize_image_tensor,
float32_vs_uint8=uint8_to_float32_threshold_adapter, float32_vs_uint8=uint8_to_float32_threshold_adapter,
closeness_kwargs=float32_vs_uint8_pixel_difference(), closeness_kwargs=float32_vs_uint8_pixel_difference(),
...@@ -1161,10 +1161,10 @@ def sample_inputs_autocontrast_video(): ...@@ -1161,10 +1161,10 @@ def sample_inputs_autocontrast_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.autocontrast_image_tensor, F.autocontrast_image,
kernel_name="autocontrast_image_tensor", kernel_name="autocontrast_image_tensor",
sample_inputs_fn=sample_inputs_autocontrast_image_tensor, sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
reference_fn=pil_reference_wrapper(F.autocontrast_image_pil), reference_fn=pil_reference_wrapper(F._autocontrast_image_pil),
reference_inputs_fn=reference_inputs_autocontrast_image_tensor, reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs={ closeness_kwargs={
...@@ -1206,10 +1206,10 @@ def sample_inputs_adjust_sharpness_video(): ...@@ -1206,10 +1206,10 @@ def sample_inputs_adjust_sharpness_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.adjust_sharpness_image_tensor, F.adjust_sharpness_image,
kernel_name="adjust_sharpness_image_tensor", kernel_name="adjust_sharpness_image_tensor",
sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor, sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil), reference_fn=pil_reference_wrapper(F._adjust_sharpness_image_pil),
reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor, reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs=float32_vs_uint8_pixel_difference(2), closeness_kwargs=float32_vs_uint8_pixel_difference(2),
...@@ -1241,7 +1241,7 @@ def sample_inputs_erase_video(): ...@@ -1241,7 +1241,7 @@ def sample_inputs_erase_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.erase_image_tensor, F.erase_image,
kernel_name="erase_image_tensor", kernel_name="erase_image_tensor",
sample_inputs_fn=sample_inputs_erase_image_tensor, sample_inputs_fn=sample_inputs_erase_image_tensor,
), ),
...@@ -1276,10 +1276,10 @@ def sample_inputs_adjust_contrast_video(): ...@@ -1276,10 +1276,10 @@ def sample_inputs_adjust_contrast_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.adjust_contrast_image_tensor, F.adjust_contrast_image,
kernel_name="adjust_contrast_image_tensor", kernel_name="adjust_contrast_image_tensor",
sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor, sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil), reference_fn=pil_reference_wrapper(F._adjust_contrast_image_pil),
reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor, reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs={ closeness_kwargs={
...@@ -1329,10 +1329,10 @@ def sample_inputs_adjust_gamma_video(): ...@@ -1329,10 +1329,10 @@ def sample_inputs_adjust_gamma_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.adjust_gamma_image_tensor, F.adjust_gamma_image,
kernel_name="adjust_gamma_image_tensor", kernel_name="adjust_gamma_image_tensor",
sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor, sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil), reference_fn=pil_reference_wrapper(F._adjust_gamma_image_pil),
reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor, reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs={ closeness_kwargs={
...@@ -1372,10 +1372,10 @@ def sample_inputs_adjust_hue_video(): ...@@ -1372,10 +1372,10 @@ def sample_inputs_adjust_hue_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.adjust_hue_image_tensor, F.adjust_hue_image,
kernel_name="adjust_hue_image_tensor", kernel_name="adjust_hue_image_tensor",
sample_inputs_fn=sample_inputs_adjust_hue_image_tensor, sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil), reference_fn=pil_reference_wrapper(F._adjust_hue_image_pil),
reference_inputs_fn=reference_inputs_adjust_hue_image_tensor, reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs={ closeness_kwargs={
...@@ -1414,10 +1414,10 @@ def sample_inputs_adjust_saturation_video(): ...@@ -1414,10 +1414,10 @@ def sample_inputs_adjust_saturation_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.adjust_saturation_image_tensor, F.adjust_saturation_image,
kernel_name="adjust_saturation_image_tensor", kernel_name="adjust_saturation_image_tensor",
sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor, sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil), reference_fn=pil_reference_wrapper(F._adjust_saturation_image_pil),
reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor, reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs={ closeness_kwargs={
...@@ -1517,8 +1517,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel): ...@@ -1517,8 +1517,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel):
def wrapper(input_tensor, *other_args, **kwargs): def wrapper(input_tensor, *other_args, **kwargs):
output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs) output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
return type(output)( return type(output)(
F.to_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype, scale=True) F.to_dtype_image(F.to_image(output_pil), dtype=input_tensor.dtype, scale=True) for output_pil in output
for output_pil in output
) )
return wrapper return wrapper
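The reference wrappers above lean on the renamed conversion kernels to round-trip between tensors and PIL before comparing against a PIL reference. A minimal sketch of that round trip, assuming the v2 functional namespace shown in this diff (the uint8 input and the rotate call are synthetic stand-ins for a real test input and PIL kernel):

import torch
from torchvision.transforms.v2 import functional as F

# Synthetic uint8 CHW image standing in for a test input.
input_tensor = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

# Tensor -> PIL, apply a PIL operation, then PIL -> tensor at the original dtype.
input_pil = F.to_pil_image(input_tensor)
output_pil = input_pil.rotate(90)  # stand-in for a PIL reference kernel
output_tensor = F.to_dtype_image(F.to_image(output_pil), dtype=input_tensor.dtype, scale=True)

assert output_tensor.dtype == input_tensor.dtype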
...@@ -1532,9 +1531,9 @@ _common_five_ten_crop_marks = [ ...@@ -1532,9 +1531,9 @@ _common_five_ten_crop_marks = [
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.five_crop_image_tensor, F.five_crop_image,
sample_inputs_fn=sample_inputs_five_crop_image_tensor, sample_inputs_fn=sample_inputs_five_crop_image_tensor,
reference_fn=multi_crop_pil_reference_wrapper(F.five_crop_image_pil), reference_fn=multi_crop_pil_reference_wrapper(F._five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor, reference_inputs_fn=reference_inputs_five_crop_image_tensor,
test_marks=_common_five_ten_crop_marks, test_marks=_common_five_ten_crop_marks,
), ),
...@@ -1544,9 +1543,9 @@ KERNEL_INFOS.extend( ...@@ -1544,9 +1543,9 @@ KERNEL_INFOS.extend(
test_marks=_common_five_ten_crop_marks, test_marks=_common_five_ten_crop_marks,
), ),
KernelInfo( KernelInfo(
F.ten_crop_image_tensor, F.ten_crop_image,
sample_inputs_fn=sample_inputs_ten_crop_image_tensor, sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
reference_fn=multi_crop_pil_reference_wrapper(F.ten_crop_image_pil), reference_fn=multi_crop_pil_reference_wrapper(F._ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor, reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
test_marks=_common_five_ten_crop_marks, test_marks=_common_five_ten_crop_marks,
), ),
...@@ -1600,7 +1599,7 @@ def sample_inputs_normalize_video(): ...@@ -1600,7 +1599,7 @@ def sample_inputs_normalize_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.normalize_image_tensor, F.normalize_image,
kernel_name="normalize_image_tensor", kernel_name="normalize_image_tensor",
sample_inputs_fn=sample_inputs_normalize_image_tensor, sample_inputs_fn=sample_inputs_normalize_image_tensor,
reference_fn=reference_normalize_image_tensor, reference_fn=reference_normalize_image_tensor,
......
...@@ -112,7 +112,7 @@ class SimpleCopyPaste(Transform): ...@@ -112,7 +112,7 @@ class SimpleCopyPaste(Transform):
if isinstance(obj, datapoints.Image) or is_simple_tensor(obj): if isinstance(obj, datapoints.Image) or is_simple_tensor(obj):
images.append(obj) images.append(obj)
elif isinstance(obj, PIL.Image.Image): elif isinstance(obj, PIL.Image.Image):
images.append(F.to_image_tensor(obj)) images.append(F.to_image(obj))
elif isinstance(obj, datapoints.BoundingBoxes): elif isinstance(obj, datapoints.BoundingBoxes):
bboxes.append(obj) bboxes.append(obj)
elif isinstance(obj, datapoints.Mask): elif isinstance(obj, datapoints.Mask):
...@@ -144,7 +144,7 @@ class SimpleCopyPaste(Transform): ...@@ -144,7 +144,7 @@ class SimpleCopyPaste(Transform):
flat_sample[i] = datapoints.wrap(output_images[c0], like=obj) flat_sample[i] = datapoints.wrap(output_images[c0], like=obj)
c0 += 1 c0 += 1
elif isinstance(obj, PIL.Image.Image): elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0]) flat_sample[i] = F.to_pil_image(output_images[c0])
c0 += 1 c0 += 1
elif is_simple_tensor(obj): elif is_simple_tensor(obj):
flat_sample[i] = output_images[c0] flat_sample[i] = output_images[c0]
......
...@@ -52,7 +52,7 @@ from ._misc import ( ...@@ -52,7 +52,7 @@ from ._misc import (
ToDtype, ToDtype,
) )
from ._temporal import UniformTemporalSubsample from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage from ._type_conversion import PILToTensor, ToImage, ToPILImage
from ._deprecated import ToTensor # usort: skip from ._deprecated import ToTensor # usort: skip
......
...@@ -622,6 +622,6 @@ class AugMix(_AutoAugmentBase): ...@@ -622,6 +622,6 @@ class AugMix(_AutoAugmentBase):
if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)): if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
mix = datapoints.wrap(mix, like=orig_image_or_video) mix = datapoints.wrap(mix, like=orig_image_or_video)
elif isinstance(orig_image_or_video, PIL.Image.Image): elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_image_pil(mix) mix = F.to_pil_image(mix)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix) return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
...@@ -26,7 +26,7 @@ class PILToTensor(Transform): ...@@ -26,7 +26,7 @@ class PILToTensor(Transform):
return F.pil_to_tensor(inpt) return F.pil_to_tensor(inpt)
class ToImageTensor(Transform): class ToImage(Transform):
"""[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image` """[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image`
; this does not scale values. ; this does not scale values.
...@@ -40,10 +40,10 @@ class ToImageTensor(Transform): ...@@ -40,10 +40,10 @@ class ToImageTensor(Transform):
def _transform( def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> datapoints.Image: ) -> datapoints.Image:
return F.to_image_tensor(inpt) return F.to_image(inpt)
class ToImagePIL(Transform): class ToPILImage(Transform):
"""[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values. """[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values.
.. v2betastatus:: ToImagePIL transform .. v2betastatus:: ToImagePIL transform
...@@ -74,9 +74,4 @@ class ToImagePIL(Transform): ...@@ -74,9 +74,4 @@ class ToImagePIL(Transform):
def _transform( def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> PIL.Image.Image: ) -> PIL.Image.Image:
return F.to_image_pil(inpt, mode=self.mode) return F.to_pil_image(inpt, mode=self.mode)
# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ToPILImage = ToImagePIL
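With ToImageTensor/ToImagePIL collapsed into ToImage/ToPILImage, user code written against the new transform names looks like the sketch below. This is an illustrative round trip, assuming the datapoints namespace used throughout this diff; the 16x16 input is arbitrary:

import numpy as np
import PIL.Image
import torch
import torchvision.transforms.v2 as transforms
from torchvision import datapoints

pil_image = PIL.Image.fromarray(np.zeros((16, 16, 3), dtype=np.uint8))

to_image = transforms.ToImage()    # was ToImageTensor; wraps input as datapoints.Image without scaling values
to_pil = transforms.ToPILImage()   # previously also reachable under the ToImagePIL alias

image = to_image(pil_image)        # datapoints.Image, uint8, CHW
assert isinstance(image, datapoints.Image)

round_tripped = to_pil(image)      # back to PIL.Image.Image
assert isinstance(round_tripped, PIL.Image.Image)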
...@@ -5,173 +5,173 @@ from ._utils import is_simple_tensor, register_kernel # usort: skip ...@@ -5,173 +5,173 @@ from ._utils import is_simple_tensor, register_kernel # usort: skip
from ._meta import ( from ._meta import (
clamp_bounding_boxes, clamp_bounding_boxes,
convert_format_bounding_boxes, convert_format_bounding_boxes,
get_dimensions_image_tensor, get_dimensions_image,
get_dimensions_image_pil, _get_dimensions_image_pil,
get_dimensions_video, get_dimensions_video,
get_dimensions, get_dimensions,
get_num_frames_video, get_num_frames_video,
get_num_frames, get_num_frames,
get_image_num_channels, get_image_num_channels,
get_num_channels_image_tensor, get_num_channels_image,
get_num_channels_image_pil, _get_num_channels_image_pil,
get_num_channels_video, get_num_channels_video,
get_num_channels, get_num_channels,
get_size_bounding_boxes, get_size_bounding_boxes,
get_size_image_tensor, get_size_image,
get_size_image_pil, _get_size_image_pil,
get_size_mask, get_size_mask,
get_size_video, get_size_video,
get_size, get_size,
) # usort: skip ) # usort: skip
from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video from ._augment import _erase_image_pil, erase, erase_image, erase_video
from ._color import ( from ._color import (
_adjust_brightness_image_pil,
_adjust_contrast_image_pil,
_adjust_gamma_image_pil,
_adjust_hue_image_pil,
_adjust_saturation_image_pil,
_adjust_sharpness_image_pil,
_autocontrast_image_pil,
_equalize_image_pil,
_invert_image_pil,
_permute_channels_image_pil,
_posterize_image_pil,
_rgb_to_grayscale_image_pil,
_solarize_image_pil,
adjust_brightness, adjust_brightness,
adjust_brightness_image_pil, adjust_brightness_image,
adjust_brightness_image_tensor,
adjust_brightness_video, adjust_brightness_video,
adjust_contrast, adjust_contrast,
adjust_contrast_image_pil, adjust_contrast_image,
adjust_contrast_image_tensor,
adjust_contrast_video, adjust_contrast_video,
adjust_gamma, adjust_gamma,
adjust_gamma_image_pil, adjust_gamma_image,
adjust_gamma_image_tensor,
adjust_gamma_video, adjust_gamma_video,
adjust_hue, adjust_hue,
adjust_hue_image_pil, adjust_hue_image,
adjust_hue_image_tensor,
adjust_hue_video, adjust_hue_video,
adjust_saturation, adjust_saturation,
adjust_saturation_image_pil, adjust_saturation_image,
adjust_saturation_image_tensor,
adjust_saturation_video, adjust_saturation_video,
adjust_sharpness, adjust_sharpness,
adjust_sharpness_image_pil, adjust_sharpness_image,
adjust_sharpness_image_tensor,
adjust_sharpness_video, adjust_sharpness_video,
autocontrast, autocontrast,
autocontrast_image_pil, autocontrast_image,
autocontrast_image_tensor,
autocontrast_video, autocontrast_video,
equalize, equalize,
equalize_image_pil, equalize_image,
equalize_image_tensor,
equalize_video, equalize_video,
invert, invert,
invert_image_pil, invert_image,
invert_image_tensor,
invert_video, invert_video,
permute_channels, permute_channels,
permute_channels_image_pil, permute_channels_image,
permute_channels_image_tensor,
permute_channels_video, permute_channels_video,
posterize, posterize,
posterize_image_pil, posterize_image,
posterize_image_tensor,
posterize_video, posterize_video,
rgb_to_grayscale, rgb_to_grayscale,
rgb_to_grayscale_image_pil, rgb_to_grayscale_image,
rgb_to_grayscale_image_tensor,
solarize, solarize,
solarize_image_pil, solarize_image,
solarize_image_tensor,
solarize_video, solarize_video,
to_grayscale, to_grayscale,
) )
from ._geometry import ( from ._geometry import (
_affine_image_pil,
_center_crop_image_pil,
_crop_image_pil,
_elastic_image_pil,
_five_crop_image_pil,
_horizontal_flip_image_pil,
_pad_image_pil,
_perspective_image_pil,
_resize_image_pil,
_resized_crop_image_pil,
_rotate_image_pil,
_ten_crop_image_pil,
_vertical_flip_image_pil,
affine, affine,
affine_bounding_boxes, affine_bounding_boxes,
affine_image_pil, affine_image,
affine_image_tensor,
affine_mask, affine_mask,
affine_video, affine_video,
center_crop, center_crop,
center_crop_bounding_boxes, center_crop_bounding_boxes,
center_crop_image_pil, center_crop_image,
center_crop_image_tensor,
center_crop_mask, center_crop_mask,
center_crop_video, center_crop_video,
crop, crop,
crop_bounding_boxes, crop_bounding_boxes,
crop_image_pil, crop_image,
crop_image_tensor,
crop_mask, crop_mask,
crop_video, crop_video,
elastic, elastic,
elastic_bounding_boxes, elastic_bounding_boxes,
elastic_image_pil, elastic_image,
elastic_image_tensor,
elastic_mask, elastic_mask,
elastic_transform, elastic_transform,
elastic_video, elastic_video,
five_crop, five_crop,
five_crop_image_pil, five_crop_image,
five_crop_image_tensor,
five_crop_video, five_crop_video,
hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file
horizontal_flip, horizontal_flip,
horizontal_flip_bounding_boxes, horizontal_flip_bounding_boxes,
horizontal_flip_image_pil, horizontal_flip_image,
horizontal_flip_image_tensor,
horizontal_flip_mask, horizontal_flip_mask,
horizontal_flip_video, horizontal_flip_video,
pad, pad,
pad_bounding_boxes, pad_bounding_boxes,
pad_image_pil, pad_image,
pad_image_tensor,
pad_mask, pad_mask,
pad_video, pad_video,
perspective, perspective,
perspective_bounding_boxes, perspective_bounding_boxes,
perspective_image_pil, perspective_image,
perspective_image_tensor,
perspective_mask, perspective_mask,
perspective_video, perspective_video,
resize, resize,
resize_bounding_boxes, resize_bounding_boxes,
resize_image_pil, resize_image,
resize_image_tensor,
resize_mask, resize_mask,
resize_video, resize_video,
resized_crop, resized_crop,
resized_crop_bounding_boxes, resized_crop_bounding_boxes,
resized_crop_image_pil, resized_crop_image,
resized_crop_image_tensor,
resized_crop_mask, resized_crop_mask,
resized_crop_video, resized_crop_video,
rotate, rotate,
rotate_bounding_boxes, rotate_bounding_boxes,
rotate_image_pil, rotate_image,
rotate_image_tensor,
rotate_mask, rotate_mask,
rotate_video, rotate_video,
ten_crop, ten_crop,
ten_crop_image_pil, ten_crop_image,
ten_crop_image_tensor,
ten_crop_video, ten_crop_video,
vertical_flip, vertical_flip,
vertical_flip_bounding_boxes, vertical_flip_bounding_boxes,
vertical_flip_image_pil, vertical_flip_image,
vertical_flip_image_tensor,
vertical_flip_mask, vertical_flip_mask,
vertical_flip_video, vertical_flip_video,
vflip, vflip,
) )
from ._misc import ( from ._misc import (
_gaussian_blur_image_pil,
convert_image_dtype, convert_image_dtype,
gaussian_blur, gaussian_blur,
gaussian_blur_image_pil, gaussian_blur_image,
gaussian_blur_image_tensor,
gaussian_blur_video, gaussian_blur_video,
normalize, normalize,
normalize_image_tensor, normalize_image,
normalize_video, normalize_video,
to_dtype, to_dtype,
to_dtype_image_tensor, to_dtype_image,
to_dtype_video, to_dtype_video,
) )
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image from ._type_conversion import pil_to_tensor, to_image, to_pil_image
from ._deprecated import get_image_size, to_tensor # usort: skip from ._deprecated import get_image_size, to_tensor # usort: skip
...@@ -18,7 +18,7 @@ def erase( ...@@ -18,7 +18,7 @@ def erase(
inplace: bool = False, inplace: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
_log_api_usage_once(erase) _log_api_usage_once(erase)
...@@ -28,7 +28,7 @@ def erase( ...@@ -28,7 +28,7 @@ def erase(
@_register_kernel_internal(erase, torch.Tensor) @_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, datapoints.Image) @_register_kernel_internal(erase, datapoints.Image)
def erase_image_tensor( def erase_image(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
if not inplace: if not inplace:
...@@ -39,11 +39,11 @@ def erase_image_tensor( ...@@ -39,11 +39,11 @@ def erase_image_tensor(
@_register_kernel_internal(erase, PIL.Image.Image) @_register_kernel_internal(erase, PIL.Image.Image)
def erase_image_pil( def _erase_image_pil(
image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image: ) -> PIL.Image.Image:
t_img = pil_to_tensor(image) t_img = pil_to_tensor(image)
output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace) output = erase_image(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return to_pil_image(output, mode=image.mode) return to_pil_image(output, mode=image.mode)
...@@ -51,4 +51,4 @@ def erase_image_pil( ...@@ -51,4 +51,4 @@ def erase_image_pil(
def erase_video( def erase_video(
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
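After this change the PIL variant is only reachable through the public erase dispatcher (or as the private _erase_image_pil), which converts PIL input to a tensor, runs erase_image, and converts back. A small sketch of calling the public entry point on both input kinds, assuming the v2 functional namespace; the erased region and fill tensor are arbitrary:

import torch
from torchvision.transforms.v2 import functional as F

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
fill = torch.zeros(3, 8, 8, dtype=torch.uint8)

# Tensor input is handled by erase_image (formerly erase_image_tensor).
erased_tensor = F.erase(img, i=4, j=4, h=8, w=8, v=fill)

# PIL input is routed to the now-private _erase_image_pil kernel,
# which round-trips through pil_to_tensor / to_pil_image internally.
erased_pil = F.erase(F.to_pil_image(img), i=4, j=4, h=8, w=8, v=fill)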
...@@ -9,14 +9,14 @@ from torchvision.transforms._functional_tensor import _max_value ...@@ -9,14 +9,14 @@ from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
from ._misc import _num_value_bits, to_dtype_image_tensor from ._misc import _num_value_bits, to_dtype_image
from ._type_conversion import pil_to_tensor, to_image_pil from ._type_conversion import pil_to_tensor, to_pil_image
from ._utils import _get_kernel, _register_kernel_internal from ._utils import _get_kernel, _register_kernel_internal
def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels) return rgb_to_grayscale_image(inpt, num_output_channels=num_output_channels)
_log_api_usage_once(rgb_to_grayscale) _log_api_usage_once(rgb_to_grayscale)
...@@ -29,7 +29,7 @@ def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch. ...@@ -29,7 +29,7 @@ def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.
to_grayscale = rgb_to_grayscale to_grayscale = rgb_to_grayscale
def _rgb_to_grayscale_image_tensor( def _rgb_to_grayscale_image(
image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
) -> torch.Tensor: ) -> torch.Tensor:
if image.shape[-3] == 1: if image.shape[-3] == 1:
...@@ -47,14 +47,14 @@ def _rgb_to_grayscale_image_tensor( ...@@ -47,14 +47,14 @@ def _rgb_to_grayscale_image_tensor(
@_register_kernel_internal(rgb_to_grayscale, torch.Tensor) @_register_kernel_internal(rgb_to_grayscale, torch.Tensor)
@_register_kernel_internal(rgb_to_grayscale, datapoints.Image) @_register_kernel_internal(rgb_to_grayscale, datapoints.Image)
def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: def rgb_to_grayscale_image(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
if num_output_channels not in (1, 3): if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True) return _rgb_to_grayscale_image(image, num_output_channels=num_output_channels, preserve_dtype=True)
@_register_kernel_internal(rgb_to_grayscale, PIL.Image.Image) @_register_kernel_internal(rgb_to_grayscale, PIL.Image.Image)
def rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
if num_output_channels not in (1, 3): if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
return _FP.to_grayscale(image, num_output_channels=num_output_channels) return _FP.to_grayscale(image, num_output_channels=num_output_channels)
...@@ -71,7 +71,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te ...@@ -71,7 +71,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor: def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) return adjust_brightness_image(inpt, brightness_factor=brightness_factor)
_log_api_usage_once(adjust_brightness) _log_api_usage_once(adjust_brightness)
...@@ -81,7 +81,7 @@ def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Ten ...@@ -81,7 +81,7 @@ def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Ten
@_register_kernel_internal(adjust_brightness, torch.Tensor) @_register_kernel_internal(adjust_brightness, torch.Tensor)
@_register_kernel_internal(adjust_brightness, datapoints.Image) @_register_kernel_internal(adjust_brightness, datapoints.Image)
def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: def adjust_brightness_image(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if brightness_factor < 0: if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
...@@ -96,18 +96,18 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float ...@@ -96,18 +96,18 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
@_register_kernel_internal(adjust_brightness, PIL.Image.Image) @_register_kernel_internal(adjust_brightness, PIL.Image.Image)
def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image: def _adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image:
return _FP.adjust_brightness(image, brightness_factor=brightness_factor) return _FP.adjust_brightness(image, brightness_factor=brightness_factor)
@_register_kernel_internal(adjust_brightness, datapoints.Video) @_register_kernel_internal(adjust_brightness, datapoints.Video)
def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor: def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) return adjust_brightness_image(video, brightness_factor=brightness_factor)
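The pattern repeated throughout this file is a public dispatcher (adjust_brightness) that short-circuits to the tensor kernel under torchscript, plus per-type kernels attached with _register_kernel_internal, where the PIL kernel now carries a leading underscore. A toy, standalone sketch of that dispatch-by-input-type idea, not the torchvision internals, might look like:

from typing import Callable, Dict

_KERNEL_REGISTRY: Dict[Callable, Dict[type, Callable]] = {}

def register_kernel_internal(dispatcher, input_type):
    # Attach a kernel to a dispatcher for a given input type (toy registry).
    def decorator(kernel):
        _KERNEL_REGISTRY.setdefault(dispatcher, {})[input_type] = kernel
        return kernel
    return decorator

def adjust_brightness(inpt, brightness_factor):
    # Look up the kernel registered for this input type and delegate to it.
    kernel = _KERNEL_REGISTRY[adjust_brightness][type(inpt)]
    return kernel(inpt, brightness_factor)

@register_kernel_internal(adjust_brightness, float)
def _adjust_brightness_number(inpt, brightness_factor):
    return inpt * brightness_factor  # stand-in "kernel": just scale a number

print(adjust_brightness(2.0, brightness_factor=1.5))  # -> 3.0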
def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor: def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) return adjust_saturation_image(inpt, saturation_factor=saturation_factor)
_log_api_usage_once(adjust_saturation) _log_api_usage_once(adjust_saturation)
...@@ -117,7 +117,7 @@ def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Ten ...@@ -117,7 +117,7 @@ def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Ten
@_register_kernel_internal(adjust_saturation, torch.Tensor) @_register_kernel_internal(adjust_saturation, torch.Tensor)
@_register_kernel_internal(adjust_saturation, datapoints.Image) @_register_kernel_internal(adjust_saturation, datapoints.Image)
def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if saturation_factor < 0: if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
...@@ -128,24 +128,24 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float ...@@ -128,24 +128,24 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
if c == 1: # Match PIL behaviour if c == 1: # Match PIL behaviour
return image return image
grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False) grayscale_image = _rgb_to_grayscale_image(image, num_output_channels=1, preserve_dtype=False)
if not image.is_floating_point(): if not image.is_floating_point():
grayscale_image = grayscale_image.floor_() grayscale_image = grayscale_image.floor_()
return _blend(image, grayscale_image, saturation_factor) return _blend(image, grayscale_image, saturation_factor)
adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation) _adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation)
@_register_kernel_internal(adjust_saturation, datapoints.Video) @_register_kernel_internal(adjust_saturation, datapoints.Video)
def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor: def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor) return adjust_saturation_image(video, saturation_factor=saturation_factor)
def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor: def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) return adjust_contrast_image(inpt, contrast_factor=contrast_factor)
_log_api_usage_once(adjust_contrast) _log_api_usage_once(adjust_contrast)
...@@ -155,7 +155,7 @@ def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor: ...@@ -155,7 +155,7 @@ def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
@_register_kernel_internal(adjust_contrast, torch.Tensor) @_register_kernel_internal(adjust_contrast, torch.Tensor)
@_register_kernel_internal(adjust_contrast, datapoints.Image) @_register_kernel_internal(adjust_contrast, datapoints.Image)
def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor: def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
if contrast_factor < 0: if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
...@@ -164,7 +164,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> ...@@ -164,7 +164,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
fp = image.is_floating_point() fp = image.is_floating_point()
if c == 3: if c == 3:
grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False) grayscale_image = _rgb_to_grayscale_image(image, num_output_channels=1, preserve_dtype=False)
if not fp: if not fp:
grayscale_image = grayscale_image.floor_() grayscale_image = grayscale_image.floor_()
else: else:
...@@ -173,17 +173,17 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> ...@@ -173,17 +173,17 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
return _blend(image, mean, contrast_factor) return _blend(image, mean, contrast_factor)
adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast) _adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast)
@_register_kernel_internal(adjust_contrast, datapoints.Video) @_register_kernel_internal(adjust_contrast, datapoints.Video)
def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor: def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor) return adjust_contrast_image(video, contrast_factor=contrast_factor)
def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor: def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) return adjust_sharpness_image(inpt, sharpness_factor=sharpness_factor)
_log_api_usage_once(adjust_sharpness) _log_api_usage_once(adjust_sharpness)
...@@ -193,7 +193,7 @@ def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tenso ...@@ -193,7 +193,7 @@ def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tenso
@_register_kernel_internal(adjust_sharpness, torch.Tensor) @_register_kernel_internal(adjust_sharpness, torch.Tensor)
@_register_kernel_internal(adjust_sharpness, datapoints.Image) @_register_kernel_internal(adjust_sharpness, datapoints.Image)
def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor: def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
num_channels, height, width = image.shape[-3:] num_channels, height, width = image.shape[-3:]
if num_channels not in (1, 3): if num_channels not in (1, 3):
raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
...@@ -245,17 +245,17 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) ...@@ -245,17 +245,17 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
return output return output
adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness) _adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness)
@_register_kernel_internal(adjust_sharpness, datapoints.Video) @_register_kernel_internal(adjust_sharpness, datapoints.Video)
def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor: def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor) return adjust_sharpness_image(video, sharpness_factor=sharpness_factor)
def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor: def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) return adjust_hue_image(inpt, hue_factor=hue_factor)
_log_api_usage_once(adjust_hue) _log_api_usage_once(adjust_hue)
...@@ -335,7 +335,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: ...@@ -335,7 +335,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(adjust_hue, torch.Tensor) @_register_kernel_internal(adjust_hue, torch.Tensor)
@_register_kernel_internal(adjust_hue, datapoints.Image) @_register_kernel_internal(adjust_hue, datapoints.Image)
def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor: def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
if not (-0.5 <= hue_factor <= 0.5): if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
...@@ -351,7 +351,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten ...@@ -351,7 +351,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
return image return image
orig_dtype = image.dtype orig_dtype = image.dtype
image = to_dtype_image_tensor(image, torch.float32, scale=True) image = to_dtype_image(image, torch.float32, scale=True)
image = _rgb_to_hsv(image) image = _rgb_to_hsv(image)
h, s, v = image.unbind(dim=-3) h, s, v = image.unbind(dim=-3)
...@@ -359,20 +359,20 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten ...@@ -359,20 +359,20 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
image = torch.stack((h, s, v), dim=-3) image = torch.stack((h, s, v), dim=-3)
image_hue_adj = _hsv_to_rgb(image) image_hue_adj = _hsv_to_rgb(image)
return to_dtype_image_tensor(image_hue_adj, orig_dtype, scale=True) return to_dtype_image(image_hue_adj, orig_dtype, scale=True)
adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue) _adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue)
@_register_kernel_internal(adjust_hue, datapoints.Video) @_register_kernel_internal(adjust_hue, datapoints.Video)
def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return adjust_hue_image_tensor(video, hue_factor=hue_factor) return adjust_hue_image(video, hue_factor=hue_factor)
def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) return adjust_gamma_image(inpt, gamma=gamma, gain=gain)
_log_api_usage_once(adjust_gamma) _log_api_usage_once(adjust_gamma)
...@@ -382,14 +382,14 @@ def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Ten ...@@ -382,14 +382,14 @@ def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Ten
@_register_kernel_internal(adjust_gamma, torch.Tensor) @_register_kernel_internal(adjust_gamma, torch.Tensor)
@_register_kernel_internal(adjust_gamma, datapoints.Image) @_register_kernel_internal(adjust_gamma, datapoints.Image)
def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor: def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
if gamma < 0: if gamma < 0:
raise ValueError("Gamma should be a non-negative real number") raise ValueError("Gamma should be a non-negative real number")
# The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer). # The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer).
# Since the gamma is non-negative, the output remains at [0, 1] scale. # Since the gamma is non-negative, the output remains at [0, 1] scale.
if not torch.is_floating_point(image): if not torch.is_floating_point(image):
output = to_dtype_image_tensor(image, torch.float32, scale=True).pow_(gamma) output = to_dtype_image(image, torch.float32, scale=True).pow_(gamma)
else: else:
output = image.pow(gamma) output = image.pow(gamma)
...@@ -398,20 +398,20 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 ...@@ -398,20 +398,20 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1
# of the output can go beyond [0, 1]. # of the output can go beyond [0, 1].
output = output.mul_(gain).clamp_(0.0, 1.0) output = output.mul_(gain).clamp_(0.0, 1.0)
return to_dtype_image_tensor(output, image.dtype, scale=True) return to_dtype_image(output, image.dtype, scale=True)
adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma) _adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma)
@_register_kernel_internal(adjust_gamma, datapoints.Video) @_register_kernel_internal(adjust_gamma, datapoints.Video)
def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain) return adjust_gamma_image(video, gamma=gamma, gain=gain)
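The comments in adjust_gamma_image describe the scaling contract: integer input is converted to float in [0, 1], the power with a non-negative gamma keeps it in [0, 1], and only the gain multiplication can push values out of range, hence the clamp before converting back. A standalone restatement of that arithmetic on a single uint8 value in plain torch (the value 128 is arbitrary, and this mirrors the logic rather than calling the kernel itself):

import torch

image = torch.tensor([[[128]]], dtype=torch.uint8)   # 1x1x1 uint8 "image"
gamma, gain = 2.0, 1.0

scaled = image.to(torch.float32) / 255.0              # uint8 -> [0, 1]
out = scaled.pow(gamma)                               # stays in [0, 1] for gamma >= 0
out = out.mul(gain).clamp(0.0, 1.0)                   # gain may leave [0, 1], so clamp
out_uint8 = (out * 255.0).round().to(torch.uint8)

print(out_uint8.item())  # (128/255)**2 ~= 0.252 -> 64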
def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor: def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return posterize_image_tensor(inpt, bits=bits) return posterize_image(inpt, bits=bits)
_log_api_usage_once(posterize) _log_api_usage_once(posterize)
...@@ -421,7 +421,7 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor: ...@@ -421,7 +421,7 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
@_register_kernel_internal(posterize, torch.Tensor) @_register_kernel_internal(posterize, torch.Tensor)
@_register_kernel_internal(posterize, datapoints.Image) @_register_kernel_internal(posterize, datapoints.Image)
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
if image.is_floating_point(): if image.is_floating_point():
levels = 1 << bits levels = 1 << bits
return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels) return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels)
...@@ -434,17 +434,17 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: ...@@ -434,17 +434,17 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
return image & mask return image & mask
posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize) _posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize)
@_register_kernel_internal(posterize, datapoints.Video) @_register_kernel_internal(posterize, datapoints.Video)
def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
return posterize_image_tensor(video, bits=bits) return posterize_image(video, bits=bits)
def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor: def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return solarize_image_tensor(inpt, threshold=threshold) return solarize_image(inpt, threshold=threshold)
_log_api_usage_once(solarize) _log_api_usage_once(solarize)
...@@ -454,24 +454,24 @@ def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor: ...@@ -454,24 +454,24 @@ def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
@_register_kernel_internal(solarize, torch.Tensor) @_register_kernel_internal(solarize, torch.Tensor)
@_register_kernel_internal(solarize, datapoints.Image) @_register_kernel_internal(solarize, datapoints.Image)
def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor: def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
if threshold > _max_value(image.dtype): if threshold > _max_value(image.dtype):
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}") raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
return torch.where(image >= threshold, invert_image_tensor(image), image) return torch.where(image >= threshold, invert_image(image), image)
solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize) _solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize)
@_register_kernel_internal(solarize, datapoints.Video) @_register_kernel_internal(solarize, datapoints.Video)
def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
return solarize_image_tensor(video, threshold=threshold) return solarize_image(video, threshold=threshold)
def autocontrast(inpt: torch.Tensor) -> torch.Tensor: def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return autocontrast_image_tensor(inpt) return autocontrast_image(inpt)
_log_api_usage_once(autocontrast) _log_api_usage_once(autocontrast)
...@@ -481,7 +481,7 @@ def autocontrast(inpt: torch.Tensor) -> torch.Tensor: ...@@ -481,7 +481,7 @@ def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(autocontrast, torch.Tensor) @_register_kernel_internal(autocontrast, torch.Tensor)
@_register_kernel_internal(autocontrast, datapoints.Image) @_register_kernel_internal(autocontrast, datapoints.Image)
def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
c = image.shape[-3] c = image.shape[-3]
if c not in [1, 3]: if c not in [1, 3]:
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
...@@ -510,17 +510,17 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: ...@@ -510,17 +510,17 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype) return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype)
autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast) _autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast)
@_register_kernel_internal(autocontrast, datapoints.Video) @_register_kernel_internal(autocontrast, datapoints.Video)
def autocontrast_video(video: torch.Tensor) -> torch.Tensor: def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
return autocontrast_image_tensor(video) return autocontrast_image(video)
def equalize(inpt: torch.Tensor) -> torch.Tensor: def equalize(inpt: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return equalize_image_tensor(inpt) return equalize_image(inpt)
_log_api_usage_once(equalize) _log_api_usage_once(equalize)
...@@ -530,7 +530,7 @@ def equalize(inpt: torch.Tensor) -> torch.Tensor: ...@@ -530,7 +530,7 @@ def equalize(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(equalize, torch.Tensor) @_register_kernel_internal(equalize, torch.Tensor)
@_register_kernel_internal(equalize, datapoints.Image) @_register_kernel_internal(equalize, datapoints.Image)
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: def equalize_image(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0: if image.numel() == 0:
return image return image
...@@ -545,7 +545,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: ...@@ -545,7 +545,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
# Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is # Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
# by far the most common, we choose it as base. # by far the most common, we choose it as base.
output_dtype = image.dtype output_dtype = image.dtype
image = to_dtype_image_tensor(image, torch.uint8, scale=True) image = to_dtype_image(image, torch.uint8, scale=True)
# The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
# corresponds to adding 1 to index 127 in the histogram. # corresponds to adding 1 to index 127 in the histogram.
...@@ -596,20 +596,20 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: ...@@ -596,20 +596,20 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image) equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
output = torch.where(valid_equalization, equalized_image, image) output = torch.where(valid_equalization, equalized_image, image)
return to_dtype_image_tensor(output, output_dtype, scale=True) return to_dtype_image(output, output_dtype, scale=True)
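The comment in `equalize_image` above describes building the histogram by using the flattened image as an index. A standalone illustration of that indexing trick on a single uint8 channel (a sketch of the idea, not the batched kernel):

import torch

channel = torch.randint(0, 256, (32, 32), dtype=torch.uint8)
flat = channel.flatten().to(torch.int64)           # pixel values become indices 0..255
hist = torch.zeros(256, dtype=torch.int64)
hist.index_add_(0, flat, torch.ones_like(flat))    # a pixel value of 127 adds 1 at index 127
# hist[127] now counts how many pixels equal 127, matching the example in the comment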
equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize) _equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize)
@_register_kernel_internal(equalize, datapoints.Video) @_register_kernel_internal(equalize, datapoints.Video)
def equalize_video(video: torch.Tensor) -> torch.Tensor: def equalize_video(video: torch.Tensor) -> torch.Tensor:
return equalize_image_tensor(video) return equalize_image(video)
def invert(inpt: torch.Tensor) -> torch.Tensor: def invert(inpt: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return invert_image_tensor(inpt) return invert_image(inpt)
_log_api_usage_once(invert) _log_api_usage_once(invert)
...@@ -619,7 +619,7 @@ def invert(inpt: torch.Tensor) -> torch.Tensor: ...@@ -619,7 +619,7 @@ def invert(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(invert, torch.Tensor) @_register_kernel_internal(invert, torch.Tensor)
@_register_kernel_internal(invert, datapoints.Image) @_register_kernel_internal(invert, datapoints.Image)
def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: def invert_image(image: torch.Tensor) -> torch.Tensor:
if image.is_floating_point(): if image.is_floating_point():
return 1.0 - image return 1.0 - image
elif image.dtype == torch.uint8: elif image.dtype == torch.uint8:
...@@ -629,12 +629,12 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: ...@@ -629,12 +629,12 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1) return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)
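The branches of `invert_image` above cover floats (1.0 - x), uint8 (bitwise NOT), and other integer dtypes, which are XOR-ed against a mask of all value bits. A small check of the uint8 and float cases (illustrative; the internal `_num_value_bits` helper is not used here):

import torch

u8 = torch.tensor([0, 1, 254, 255], dtype=torch.uint8)
assert torch.equal(u8.bitwise_not(), 255 - u8)   # for uint8, NOT is the same as 255 - x

f32 = torch.tensor([0.0, 0.25, 1.0])
print(1.0 - f32)                                 # for float images in [0, 1], invert is 1 - x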
invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert) _invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert)
@_register_kernel_internal(invert, datapoints.Video) @_register_kernel_internal(invert, datapoints.Video)
def invert_video(video: torch.Tensor) -> torch.Tensor: def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image_tensor(video) return invert_image(video)
def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor: def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
...@@ -660,7 +660,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor ...@@ -660,7 +660,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor
ValueError: If ``len(permutation)`` doesn't match the number of channels in the input. ValueError: If ``len(permutation)`` doesn't match the number of channels in the input.
""" """
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return permute_channels_image_tensor(inpt, permutation=permutation) return permute_channels_image(inpt, permutation=permutation)
_log_api_usage_once(permute_channels) _log_api_usage_once(permute_channels)
...@@ -670,7 +670,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor ...@@ -670,7 +670,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor
@_register_kernel_internal(permute_channels, torch.Tensor) @_register_kernel_internal(permute_channels, torch.Tensor)
@_register_kernel_internal(permute_channels, datapoints.Image) @_register_kernel_internal(permute_channels, datapoints.Image)
def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor: def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
shape = image.shape shape = image.shape
num_channels, height, width = shape[-3:] num_channels, height, width = shape[-3:]
...@@ -688,10 +688,10 @@ def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) - ...@@ -688,10 +688,10 @@ def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -
@_register_kernel_internal(permute_channels, PIL.Image.Image) @_register_kernel_internal(permute_channels, PIL.Image.Image)
def permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image: def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image:
return to_image_pil(permute_channels_image_tensor(pil_to_tensor(image), permutation=permutation)) return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation))
@_register_kernel_internal(permute_channels, datapoints.Video) @_register_kernel_internal(permute_channels, datapoints.Video)
def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor: def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
return permute_channels_image_tensor(video, permutation=permutation) return permute_channels_image(video, permutation=permutation)
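A short usage sketch for the renamed `permute_channels` functional above, assuming it is exported from the v2 functional namespace; per its docstring, the i-th output channel is the `permutation[i]`-th input channel:

import torch
from torchvision.transforms.v2 import functional as F

img = torch.rand(3, 4, 4)                              # RGB image, CHW layout
bgr = F.permute_channels(img, permutation=[2, 1, 0])   # swap the R and B channels
assert torch.equal(bgr[0], img[2]) and torch.equal(bgr[2], img[0])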