Unverified commit ca012d39, authored by Philip Meier and committed by GitHub

make PIL kernels private (#7831)

parent cdbbd666
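
The hunks below rename the v2 conversion transforms and tensor kernels and make the PIL kernels private. A minimal sketch of the renaming pattern, using only names that appear in the diff (signatures and the rest of the API surface are assumed, not shown here):

import torch
from torchvision.transforms import v2 as transforms
from torchvision.transforms.v2 import functional as F

# Functional conversions: to_image_tensor -> to_image, to_image_pil -> to_pil_image,
# to_dtype_image_tensor -> to_dtype_image (names taken from the hunks below).
image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
pil_image = F.to_pil_image(image)          # previously F.to_image_pil(image)
image = F.to_image(pil_image)              # previously F.to_image_tensor(pil_image)
image = F.to_dtype_image(image, dtype=torch.float32, scale=True)

# The transform classes follow the same pattern: ToImageTensor -> ToImage.
pipeline = transforms.Compose([
    transforms.ToImage(),                  # previously transforms.ToImageTensor()
    transforms.ConvertImageDtype(torch.float32),
])

# Tensor kernels drop the "_tensor" suffix (F.resize_image_tensor -> F.resize_image), and the
# PIL kernels become private (F.resize_image_pil -> F._resize_image_pil); PIL inputs are
# expected to go through the public dispatchers such as F.resize instead.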
@@ -228,12 +228,11 @@ Conversion
 ToPILImage
 v2.ToPILImage
-v2.ToImagePIL
 ToTensor
 v2.ToTensor
 PILToTensor
 v2.PILToTensor
-v2.ToImageTensor
+v2.ToImage
 ConvertImageDtype
 v2.ConvertImageDtype
 v2.ToDtype
...
@@ -27,7 +27,7 @@ def show(sample):
 image, target = sample
 if isinstance(image, PIL.Image.Image):
-image = F.to_image_tensor(image)
+image = F.to_image(image)
 image = F.to_dtype(image, torch.uint8, scale=True)
 annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
@@ -101,7 +101,7 @@ transform = transforms.Compose(
 transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
 transforms.RandomIoUCrop(),
 transforms.RandomHorizontalFlip(),
-transforms.ToImageTensor(),
+transforms.ToImage(),
 transforms.ConvertImageDtype(torch.float32),
 transforms.SanitizeBoundingBoxes(),
 ]
...
@@ -33,7 +33,7 @@ class DetectionPresetTrain:
 transforms = []
 backend = backend.lower()
 if backend == "datapoint":
-transforms.append(T.ToImageTensor())
+transforms.append(T.ToImage())
 elif backend == "tensor":
 transforms.append(T.PILToTensor())
 elif backend != "pil":
@@ -71,7 +71,7 @@ class DetectionPresetTrain:
 if backend == "pil":
 # Note: we could just convert to pure tensors even in v2.
-transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
 transforms += [T.ConvertImageDtype(torch.float)]
@@ -94,11 +94,11 @@ class DetectionPresetEval:
 backend = backend.lower()
 if backend == "pil":
 # Note: we could just convert to pure tensors even in v2?
-transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
 elif backend == "tensor":
 transforms += [T.PILToTensor()]
 elif backend == "datapoint":
-transforms += [T.ToImageTensor()]
+transforms += [T.ToImage()]
 else:
 raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
...
@@ -32,7 +32,7 @@ class SegmentationPresetTrain:
 transforms = []
 backend = backend.lower()
 if backend == "datapoint":
-transforms.append(T.ToImageTensor())
+transforms.append(T.ToImage())
 elif backend == "tensor":
 transforms.append(T.PILToTensor())
 elif backend != "pil":
@@ -81,7 +81,7 @@ class SegmentationPresetEval:
 if backend == "tensor":
 transforms += [T.PILToTensor()]
 elif backend == "datapoint":
-transforms += [T.ToImageTensor()]
+transforms += [T.ToImage()]
 elif backend != "pil":
 raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
@@ -92,7 +92,7 @@ class SegmentationPresetEval:
 if backend == "pil":
 # Note: we could just convert to pure tensors even in v2?
-transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
 transforms += [
 T.ConvertImageDtype(torch.float),
...
@@ -27,7 +27,7 @@ from PIL import Image
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import datapoints, io
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import to_dtype_image_tensor, to_image_pil, to_image_tensor
+from torchvision.transforms.v2.functional import to_dtype_image, to_image, to_pil_image
 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -293,7 +293,7 @@ class ImagePair(TensorLikePair):
 **other_parameters,
 ):
 if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
-actual, expected = [to_image_tensor(input) for input in [actual, expected]]
+actual, expected = [to_image(input) for input in [actual, expected]]
 super().__init__(actual, expected, **other_parameters)
 self.mae = mae
@@ -536,7 +536,7 @@ def make_image_tensor(*args, **kwargs):
 def make_image_pil(*args, **kwargs):
-return to_image_pil(make_image(*args, **kwargs))
+return to_pil_image(make_image(*args, **kwargs))
 def make_image_loader(
@@ -609,12 +609,12 @@ def make_image_loader_for_interpolation(
 )
 )
-image_tensor = to_image_tensor(image_pil)
+image_tensor = to_image(image_pil)
 if memory_format == torch.contiguous_format:
 image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True)
 else:
 image_tensor = image_tensor.to(device=device)
-image_tensor = to_dtype_image_tensor(image_tensor, dtype=dtype, scale=True)
+image_tensor = to_dtype_image(image_tensor, dtype=dtype, scale=True)
 return datapoints.Image(image_tensor)
...
@@ -17,7 +17,7 @@ from prototype_common_utils import make_label
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
-from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
+from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor
 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -387,7 +387,7 @@ def test_fixed_sized_crop_against_detection_reference():
 size = (600, 800)
 num_objects = 22
-pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
+pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
 target = {
 "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
 "labels": make_label(extra_dims=(num_objects,), categories=80),
...
@@ -666,19 +666,19 @@ class TestTransform:
 t(inpt)
-class TestToImageTensor:
+class TestToImage:
 @pytest.mark.parametrize(
 "inpt_type",
 [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
 )
 def test__transform(self, inpt_type, mocker):
 fn = mocker.patch(
-"torchvision.transforms.v2.functional.to_image_tensor",
+"torchvision.transforms.v2.functional.to_image",
 return_value=torch.rand(1, 3, 8, 8),
 )
 inpt = mocker.MagicMock(spec=inpt_type)
-transform = transforms.ToImageTensor()
+transform = transforms.ToImage()
 transform(inpt)
 if inpt_type in (datapoints.BoundingBoxes, datapoints.Image, str, int):
 assert fn.call_count == 0
@@ -686,30 +686,13 @@ class TestToImageTensor:
 fn.assert_called_once_with(inpt)
-class TestToImagePIL:
-@pytest.mark.parametrize(
-"inpt_type",
-[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
-)
-def test__transform(self, inpt_type, mocker):
-fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")
-inpt = mocker.MagicMock(spec=inpt_type)
-transform = transforms.ToImagePIL()
-transform(inpt)
-if inpt_type in (datapoints.BoundingBoxes, PIL.Image.Image, str, int):
-assert fn.call_count == 0
-else:
-fn.assert_called_once_with(inpt, mode=transform.mode)
 class TestToPILImage:
 @pytest.mark.parametrize(
 "inpt_type",
 [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
 )
 def test__transform(self, inpt_type, mocker):
-fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")
+fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image")
 inpt = mocker.MagicMock(spec=inpt_type)
 transform = transforms.ToPILImage()
@@ -1013,7 +996,7 @@ def test_antialias_warning():
 @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
 @pytest.mark.parametrize("label_type", (torch.Tensor, int))
 @pytest.mark.parametrize("dataset_return_type", (dict, tuple))
-@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
 def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
 image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
@@ -1074,7 +1057,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
 @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
 @pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
-@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
+@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
 @pytest.mark.parametrize("sanitize", (True, False))
 def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
 torch.manual_seed(0)
...
@@ -30,7 +30,7 @@ from torchvision._utils import sequence_to_str
 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
 from torchvision.transforms.v2._utils import _get_fill
-from torchvision.transforms.v2.functional import to_image_pil
+from torchvision.transforms.v2.functional import to_pil_image
 from torchvision.transforms.v2.utils import query_size
 DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
@@ -630,7 +630,7 @@ def check_call_consistency(
 )
 if image.ndim == 3 and supports_pil:
-image_pil = to_image_pil(image)
+image_pil = to_pil_image(image)
 try:
 torch.manual_seed(0)
@@ -869,7 +869,7 @@ class TestToTensorTransforms:
 legacy_transform = legacy_transforms.PILToTensor()
 for image in make_images(extra_dims=[()]):
-image_pil = to_image_pil(image)
+image_pil = to_pil_image(image)
 assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
@@ -879,7 +879,7 @@ class TestToTensorTransforms:
 legacy_transform = legacy_transforms.ToTensor()
 for image in make_images(extra_dims=[()]):
-image_pil = to_image_pil(image)
+image_pil = to_pil_image(image)
 image_numpy = np.array(image_pil)
 assert_equal(prototype_transform(image_pil), legacy_transform(image_pil))
@@ -1088,7 +1088,7 @@ class TestRefDetTransforms:
 def make_label(extra_dims, categories):
 return torch.randint(categories, extra_dims, dtype=torch.int64)
-pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
+pil_image = to_pil_image(make_image(size=size, color_space="RGB"))
 target = {
 "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
 "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -1192,7 +1192,7 @@ class TestRefSegTransforms:
 conv_fns = []
 if supports_pil:
-conv_fns.append(to_image_pil)
+conv_fns.append(to_pil_image)
 conv_fns.extend([torch.Tensor, lambda x: x])
 for conv_fn in conv_fns:
@@ -1201,8 +1201,8 @@ class TestRefSegTransforms:
 dp = (conv_fn(datapoint_image), datapoint_mask)
 dp_ref = (
-to_image_pil(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
-to_image_pil(datapoint_mask),
+to_pil_image(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
+to_pil_image(datapoint_mask),
 )
 yield dp, dp_ref
...
@@ -280,12 +280,12 @@ class TestKernels:
 adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs)
 actual = info.kernel(
-F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True),
+F.to_dtype_image(input, dtype=torch.float32, scale=True),
 *adapted_other_args,
 **adapted_kwargs,
 )
-expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)
+expected = F.to_dtype_image(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True)
 assert_close(
 actual,
@@ -377,7 +377,7 @@ class TestDispatchers:
 if image_datapoint.ndim > 3:
 pytest.skip("Input is batched")
-image_pil = F.to_image_pil(image_datapoint)
+image_pil = F.to_pil_image(image_datapoint)
 output = info.dispatcher(image_pil, *other_args, **kwargs)
@@ -470,7 +470,7 @@ class TestDispatchers:
 (F.hflip, F.horizontal_flip),
 (F.vflip, F.vertical_flip),
 (F.get_image_num_channels, F.get_num_channels),
-(F.to_pil_image, F.to_image_pil),
+(F.to_pil_image, F.to_pil_image),
 (F.elastic_transform, F.elastic),
 (F.to_grayscale, F.rgb_to_grayscale),
 ]
@@ -493,7 +493,7 @@ def test_normalize_image_tensor_stats(device, num_channels):
 mean = image.mean(dim=(1, 2)).tolist()
 std = image.std(dim=(1, 2)).tolist()
-assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))
+assert_samples_from_standard_normal(F.normalize_image(image, mean, std))
 class TestClampBoundingBoxes:
@@ -899,7 +899,7 @@ def test_correctness_center_crop_mask(device, output_size):
 _, image_height, image_width = mask.shape
 if crop_width > image_height or crop_height > image_width:
 padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-mask = F.pad_image_tensor(mask, padding, fill=0)
+mask = F.pad_image(mask, padding, fill=0)
 left = round((image_width - crop_width) * 0.5)
 top = round((image_height - crop_height) * 0.5)
@@ -920,7 +920,7 @@ def test_correctness_center_crop_mask(device, output_size):
 @pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
 @pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
 def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
-fn = F.gaussian_blur_image_tensor
+fn = F.gaussian_blur_image
 # true_cv2_results = {
 # # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
@@ -977,8 +977,8 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize,
 PIL.Image.new("RGB", (32, 32), 122),
 ],
 )
-def test_to_image_tensor(inpt):
-output = F.to_image_tensor(inpt)
+def test_to_image(inpt):
+output = F.to_image(inpt)
 assert isinstance(output, torch.Tensor)
 assert output.shape == (3, 32, 32)
@@ -993,8 +993,8 @@ def test_to_image_tensor(inpt):
 ],
 )
 @pytest.mark.parametrize("mode", [None, "RGB"])
-def test_to_image_pil(inpt, mode):
-output = F.to_image_pil(inpt, mode=mode)
+def test_to_pil_image(inpt, mode):
+output = F.to_pil_image(inpt, mode=mode)
 assert isinstance(output, PIL.Image.Image)
 assert np.asarray(inpt).sum() == np.asarray(output).sum()
@@ -1002,12 +1002,12 @@ def test_to_image_pil(inpt, mode):
 def test_equalize_image_tensor_edge_cases():
 inpt = torch.zeros(3, 200, 200, dtype=torch.uint8)
-output = F.equalize_image_tensor(inpt)
+output = F.equalize_image(inpt)
 torch.testing.assert_close(inpt, output)
 inpt = torch.zeros(5, 3, 200, 200, dtype=torch.uint8)
 inpt[..., 100:, 100:] = 1
-output = F.equalize_image_tensor(inpt)
+output = F.equalize_image(inpt)
 assert output.unique().tolist() == [0, 255]
@@ -1024,7 +1024,7 @@ def test_correctness_uniform_temporal_subsample(device):
 # TODO: We can remove this test and related torchvision workaround
 # once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
 @make_info_args_kwargs_parametrization(
-[info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor],
+[info for info in KERNEL_INFOS if info.kernel is F.resize_image],
 args_kwargs_fn=lambda info: info.reference_inputs_fn(),
 )
 def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
...
@@ -437,7 +437,7 @@ class TestResize:
 check_cuda_vs_cpu_tolerances = dict(rtol=0, atol=atol / 255 if dtype.is_floating_point else atol)
 check_kernel(
-F.resize_image_tensor,
+F.resize_image,
 make_image(self.INPUT_SIZE, dtype=dtype, device=device),
 size=size,
 interpolation=interpolation,
@@ -495,9 +495,9 @@ class TestResize:
 @pytest.mark.parametrize(
 ("kernel", "input_type"),
 [
-(F.resize_image_tensor, torch.Tensor),
-(F.resize_image_pil, PIL.Image.Image),
-(F.resize_image_tensor, datapoints.Image),
+(F.resize_image, torch.Tensor),
+(F._resize_image_pil, PIL.Image.Image),
+(F.resize_image, datapoints.Image),
 (F.resize_bounding_boxes, datapoints.BoundingBoxes),
 (F.resize_mask, datapoints.Mask),
 (F.resize_video, datapoints.Video),
@@ -541,9 +541,7 @@ class TestResize:
 image = make_image(self.INPUT_SIZE, dtype=torch.uint8)
 actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=True)
-expected = F.to_image_tensor(
-F.resize(F.to_image_pil(image), size=size, interpolation=interpolation, **max_size_kwarg)
-)
+expected = F.to_image(F.resize(F.to_pil_image(image), size=size, interpolation=interpolation, **max_size_kwarg))
 self._check_output_size(image, actual, size=size, **max_size_kwarg)
 torch.testing.assert_close(actual, expected, atol=1, rtol=0)
@@ -739,7 +737,7 @@ class TestHorizontalFlip:
 @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
 @pytest.mark.parametrize("device", cpu_and_cuda())
 def test_kernel_image_tensor(self, dtype, device):
-check_kernel(F.horizontal_flip_image_tensor, make_image(dtype=dtype, device=device))
+check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device))
 @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
 @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@@ -770,9 +768,9 @@ class TestHorizontalFlip:
 @pytest.mark.parametrize(
 ("kernel", "input_type"),
 [
-(F.horizontal_flip_image_tensor, torch.Tensor),
-(F.horizontal_flip_image_pil, PIL.Image.Image),
-(F.horizontal_flip_image_tensor, datapoints.Image),
+(F.horizontal_flip_image, torch.Tensor),
+(F._horizontal_flip_image_pil, PIL.Image.Image),
+(F.horizontal_flip_image, datapoints.Image),
 (F.horizontal_flip_bounding_boxes, datapoints.BoundingBoxes),
 (F.horizontal_flip_mask, datapoints.Mask),
 (F.horizontal_flip_video, datapoints.Video),
@@ -796,7 +794,7 @@ class TestHorizontalFlip:
 image = make_image(dtype=torch.uint8, device="cpu")
 actual = fn(image)
-expected = F.to_image_tensor(F.horizontal_flip(F.to_image_pil(image)))
+expected = F.to_image(F.horizontal_flip(F.to_pil_image(image)))
 torch.testing.assert_close(actual, expected)
@@ -900,7 +898,7 @@ class TestAffine:
 if param == "fill":
 value = adapt_fill(value, dtype=dtype)
 self._check_kernel(
-F.affine_image_tensor,
+F.affine_image,
 make_image(dtype=dtype, device=device),
 **{param: value},
 check_scripted_vs_eager=not (param in {"shear", "fill"} and isinstance(value, (int, float))),
@@ -946,9 +944,9 @@ class TestAffine:
 @pytest.mark.parametrize(
 ("kernel", "input_type"),
 [
-(F.affine_image_tensor, torch.Tensor),
-(F.affine_image_pil, PIL.Image.Image),
-(F.affine_image_tensor, datapoints.Image),
+(F.affine_image, torch.Tensor),
+(F._affine_image_pil, PIL.Image.Image),
+(F.affine_image, datapoints.Image),
 (F.affine_bounding_boxes, datapoints.BoundingBoxes),
 (F.affine_mask, datapoints.Mask),
 (F.affine_video, datapoints.Video),
@@ -991,9 +989,9 @@ class TestAffine:
 interpolation=interpolation,
 fill=fill,
 )
-expected = F.to_image_tensor(
+expected = F.to_image(
 F.affine(
-F.to_image_pil(image),
+F.to_pil_image(image),
 angle=angle,
 translate=translate,
 scale=scale,
@@ -1026,7 +1024,7 @@ class TestAffine:
 actual = transform(image)
 torch.manual_seed(seed)
-expected = F.to_image_tensor(transform(F.to_image_pil(image)))
+expected = F.to_image(transform(F.to_pil_image(image)))
 mae = (actual.float() - expected.float()).abs().mean()
 assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
@@ -1204,7 +1202,7 @@ class TestVerticalFlip:
 @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
 @pytest.mark.parametrize("device", cpu_and_cuda())
 def test_kernel_image_tensor(self, dtype, device):
-check_kernel(F.vertical_flip_image_tensor, make_image(dtype=dtype, device=device))
+check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device))
 @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
 @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@@ -1235,9 +1233,9 @@ class TestVerticalFlip:
 @pytest.mark.parametrize(
 ("kernel", "input_type"),
 [
-(F.vertical_flip_image_tensor, torch.Tensor),
-(F.vertical_flip_image_pil, PIL.Image.Image),
-(F.vertical_flip_image_tensor, datapoints.Image),
+(F.vertical_flip_image, torch.Tensor),
+(F._vertical_flip_image_pil, PIL.Image.Image),
+(F.vertical_flip_image, datapoints.Image),
 (F.vertical_flip_bounding_boxes, datapoints.BoundingBoxes),
 (F.vertical_flip_mask, datapoints.Mask),
 (F.vertical_flip_video, datapoints.Video),
@@ -1259,7 +1257,7 @@ class TestVerticalFlip:
 image = make_image(dtype=torch.uint8, device="cpu")
 actual = fn(image)
-expected = F.to_image_tensor(F.vertical_flip(F.to_image_pil(image)))
+expected = F.to_image(F.vertical_flip(F.to_pil_image(image)))
 torch.testing.assert_close(actual, expected)
@@ -1339,7 +1337,7 @@ class TestRotate:
 if param != "angle":
 kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"]
 check_kernel(
-F.rotate_image_tensor,
+F.rotate_image,
 make_image(dtype=dtype, device=device),
 **kwargs,
 check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
@@ -1385,9 +1383,9 @@ class TestRotate:
 @pytest.mark.parametrize(
 ("kernel", "input_type"),
 [
-(F.rotate_image_tensor, torch.Tensor),
-(F.rotate_image_pil, PIL.Image.Image),
-(F.rotate_image_tensor, datapoints.Image),
+(F.rotate_image, torch.Tensor),
+(F._rotate_image_pil, PIL.Image.Image),
+(F.rotate_image, datapoints.Image),
 (F.rotate_bounding_boxes, datapoints.BoundingBoxes),
 (F.rotate_mask, datapoints.Mask),
 (F.rotate_video, datapoints.Video),
@@ -1419,9 +1417,9 @@ class TestRotate:
 fill = adapt_fill(fill, dtype=torch.uint8)
 actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill)
-expected = F.to_image_tensor(
+expected = F.to_image(
 F.rotate(
-F.to_image_pil(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill
+F.to_pil_image(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill
 )
 )
@@ -1452,7 +1450,7 @@ class TestRotate:
 actual = transform(image)
 torch.manual_seed(seed)
-expected = F.to_image_tensor(transform(F.to_image_pil(image)))
+expected = F.to_image(transform(F.to_pil_image(image)))
 mae = (actual.float() - expected.float()).abs().mean()
 assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6
@@ -1621,8 +1619,8 @@ class TestToDtype:
 @pytest.mark.parametrize(
 ("kernel", "make_input"),
 [
-(F.to_dtype_image_tensor, make_image_tensor),
-(F.to_dtype_image_tensor, make_image),
+(F.to_dtype_image, make_image_tensor),
+(F.to_dtype_image, make_image),
 (F.to_dtype_video, make_video),
 ],
 )
@@ -1801,7 +1799,7 @@ class TestAdjustBrightness:
 @pytest.mark.parametrize(
 ("kernel", "make_input"),
 [
-(F.adjust_brightness_image_tensor, make_image),
+(F.adjust_brightness_image, make_image),
 (F.adjust_brightness_video, make_video),
 ],
 )
@@ -1817,9 +1815,9 @@ class TestAdjustBrightness:
 @pytest.mark.parametrize(
 ("kernel", "input_type"),
 [
-(F.adjust_brightness_image_tensor, torch.Tensor),
-(F.adjust_brightness_image_pil, PIL.Image.Image),
-(F.adjust_brightness_image_tensor, datapoints.Image),
+(F.adjust_brightness_image, torch.Tensor),
+(F._adjust_brightness_image_pil, PIL.Image.Image),
+(F.adjust_brightness_image, datapoints.Image),
 (F.adjust_brightness_video, datapoints.Video),
 ],
 )
@@ -1831,7 +1829,7 @@ class TestAdjustBrightness:
 image = make_image(dtype=torch.uint8, device="cpu")
 actual = F.adjust_brightness(image, brightness_factor=brightness_factor)
-expected = F.to_image_tensor(F.adjust_brightness(F.to_image_pil(image), brightness_factor=brightness_factor))
+expected = F.to_image(F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor))
 torch.testing.assert_close(actual, expected)
@@ -1979,9 +1977,9 @@ class TestShapeGetters:
 @pytest.mark.parametrize(
 ("kernel", "make_input"),
 [
-(F.get_dimensions_image_tensor, make_image_tensor),
-(F.get_dimensions_image_pil, make_image_pil),
-(F.get_dimensions_image_tensor, make_image),
+(F.get_dimensions_image, make_image_tensor),
+(F._get_dimensions_image_pil, make_image_pil),
+(F.get_dimensions_image, make_image),
 (F.get_dimensions_video, make_video),
 ],
 )
@@ -1996,9 +1994,9 @@ class TestShapeGetters:
 @pytest.mark.parametrize(
 ("kernel", "make_input"),
 [
-(F.get_num_channels_image_tensor, make_image_tensor),
-(F.get_num_channels_image_pil, make_image_pil),
-(F.get_num_channels_image_tensor, make_image),
+(F.get_num_channels_image, make_image_tensor),
+(F._get_num_channels_image_pil, make_image_pil),
+(F.get_num_channels_image, make_image),
 (F.get_num_channels_video, make_video),
 ],
 )
@@ -2012,9 +2010,9 @@ class TestShapeGetters:
 @pytest.mark.parametrize(
 ("kernel", "make_input"),
 [
-(F.get_size_image_tensor, make_image_tensor),
-(F.get_size_image_pil, make_image_pil),
-(F.get_size_image_tensor, make_image),
+(F.get_size_image, make_image_tensor),
+(F._get_size_image_pil, make_image_pil),
+(F.get_size_image, make_image),
 (F.get_size_bounding_boxes, make_bounding_box),
 (F.get_size_mask, make_detection_mask),
 (F.get_size_mask, make_segmentation_mask),
@@ -2101,7 +2099,7 @@ class TestRegisterKernel:
 F.register_kernel(F.resize, object)
 with pytest.raises(ValueError, match="cannot be registered for the builtin datapoint classes"):
-F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)
+F.register_kernel(F.resize, datapoints.Image)(F.resize_image)
 class CustomDatapoint(datapoints.Datapoint):
 pass
@@ -2119,9 +2117,9 @@ class TestGetKernel:
 # We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination
 # would also be fine
 KERNELS = {
-torch.Tensor: F.resize_image_tensor,
-PIL.Image.Image: F.resize_image_pil,
-datapoints.Image: F.resize_image_tensor,
+torch.Tensor: F.resize_image,
+PIL.Image.Image: F._resize_image_pil,
+datapoints.Image: F.resize_image,
 datapoints.BoundingBoxes: F.resize_bounding_boxes,
 datapoints.Mask: F.resize_mask,
 datapoints.Video: F.resize_video,
@@ -2217,10 +2215,10 @@ class TestPermuteChannels:
 @pytest.mark.parametrize(
 ("kernel", "make_input"),
 [
-(F.permute_channels_image_tensor, make_image_tensor),
+(F.permute_channels_image, make_image_tensor),
 # FIXME
 # check_kernel does not support PIL kernel, but it should
-(F.permute_channels_image_tensor, make_image),
+(F.permute_channels_image, make_image),
 (F.permute_channels_video, make_video),
 ],
 )
@@ -2236,9 +2234,9 @@ class TestPermuteChannels:
 @pytest.mark.parametrize(
 ("kernel", "input_type"),
 [
-(F.permute_channels_image_tensor, torch.Tensor),
-(F.permute_channels_image_pil, PIL.Image.Image),
-(F.permute_channels_image_tensor, datapoints.Image),
+(F.permute_channels_image, torch.Tensor),
+(F._permute_channels_image_pil, PIL.Image.Image),
+(F.permute_channels_image, datapoints.Image),
 (F.permute_channels_video, datapoints.Video),
 ],
 )
...
@@ -7,7 +7,7 @@ import torchvision.transforms.v2.utils
 from common_utils import DEFAULT_SIZE, make_bounding_box, make_detection_mask, make_image
 from torchvision import datapoints
-from torchvision.transforms.v2.functional import to_image_pil
+from torchvision.transforms.v2.functional import to_pil_image
 from torchvision.transforms.v2.utils import has_all, has_any
@@ -44,7 +44,7 @@ MASK = make_detection_mask(DEFAULT_SIZE)
 True,
 ),
 (
-(to_image_pil(IMAGE),),
+(to_pil_image(IMAGE),),
 (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
 True,
 ),
...
@@ -142,32 +142,32 @@ DISPATCHER_INFOS = [
 DispatcherInfo(
 F.crop,
 kernels={
-datapoints.Image: F.crop_image_tensor,
+datapoints.Image: F.crop_image,
 datapoints.Video: F.crop_video,
 datapoints.BoundingBoxes: F.crop_bounding_boxes,
 datapoints.Mask: F.crop_mask,
 },
-pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),
+pil_kernel_info=PILKernelInfo(F._crop_image_pil, kernel_name="crop_image_pil"),
 ),
 DispatcherInfo(
 F.resized_crop,
 kernels={
-datapoints.Image: F.resized_crop_image_tensor,
+datapoints.Image: F.resized_crop_image,
 datapoints.Video: F.resized_crop_video,
 datapoints.BoundingBoxes: F.resized_crop_bounding_boxes,
 datapoints.Mask: F.resized_crop_mask,
 },
-pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil),
+pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil),
 ),
 DispatcherInfo(
 F.pad,
 kernels={
-datapoints.Image: F.pad_image_tensor,
+datapoints.Image: F.pad_image,
 datapoints.Video: F.pad_video,
 datapoints.BoundingBoxes: F.pad_bounding_boxes,
 datapoints.Mask: F.pad_mask,
 },
-pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),
+pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
 test_marks=[
 *xfails_pil(
 reason=(
@@ -184,12 +184,12 @@ DISPATCHER_INFOS = [
 DispatcherInfo(
 F.perspective,
 kernels={
-datapoints.Image: F.perspective_image_tensor,
+datapoints.Image: F.perspective_image,
 datapoints.Video: F.perspective_video,
 datapoints.BoundingBoxes: F.perspective_bounding_boxes,
 datapoints.Mask: F.perspective_mask,
 },
-pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
+pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
 test_marks=[
 *xfails_pil_if_fill_sequence_needs_broadcast,
 xfail_jit_python_scalar_arg("fill"),
@@ -198,23 +198,23 @@ DISPATCHER_INFOS = [
 DispatcherInfo(
 F.elastic,
 kernels={
-datapoints.Image: F.elastic_image_tensor,
+datapoints.Image: F.elastic_image,
 datapoints.Video: F.elastic_video,
 datapoints.BoundingBoxes: F.elastic_bounding_boxes,
 datapoints.Mask: F.elastic_mask,
 },
-pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
+pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
 test_marks=[xfail_jit_python_scalar_arg("fill")],
 ),
 DispatcherInfo(
 F.center_crop,
 kernels={
-datapoints.Image: F.center_crop_image_tensor,
+datapoints.Image: F.center_crop_image,
 datapoints.Video: F.center_crop_video,
 datapoints.BoundingBoxes: F.center_crop_bounding_boxes,
 datapoints.Mask: F.center_crop_mask,
 },
-pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),
+pil_kernel_info=PILKernelInfo(F._center_crop_image_pil),
 test_marks=[
 xfail_jit_python_scalar_arg("output_size"),
 ],
@@ -222,10 +222,10 @@ DISPATCHER_INFOS = [
 DispatcherInfo(
 F.gaussian_blur,
 kernels={
-datapoints.Image: F.gaussian_blur_image_tensor,
+datapoints.Image: F.gaussian_blur_image,
 datapoints.Video: F.gaussian_blur_video,
 },
-pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil),
+pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil),
 test_marks=[
 xfail_jit_python_scalar_arg("kernel_size"),
 xfail_jit_python_scalar_arg("sigma"),
@@ -234,58 +234,58 @@ DISPATCHER_INFOS = [
 DispatcherInfo(
 F.equalize,
 kernels={
-datapoints.Image: F.equalize_image_tensor,
+datapoints.Image: F.equalize_image,
 datapoints.Video: F.equalize_video,
 },
-pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"),
+pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
 ),
 DispatcherInfo(
 F.invert,
 kernels={
-datapoints.Image: F.invert_image_tensor,
+datapoints.Image: F.invert_image,
 datapoints.Video: F.invert_video,
 },
-pil_kernel_info=PILKernelInfo(F.invert_image_pil, kernel_name="invert_image_pil"),
+pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
 ),
 DispatcherInfo(
 F.posterize,
 kernels={
-datapoints.Image: F.posterize_image_tensor,
+datapoints.Image: F.posterize_image,
 datapoints.Video: F.posterize_video,
 },
-pil_kernel_info=PILKernelInfo(F.posterize_image_pil, kernel_name="posterize_image_pil"),
+pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
 ),
 DispatcherInfo(
 F.solarize,
 kernels={
-datapoints.Image: F.solarize_image_tensor,
+datapoints.Image: F.solarize_image,
 datapoints.Video: F.solarize_video,
 },
-pil_kernel_info=PILKernelInfo(F.solarize_image_pil, kernel_name="solarize_image_pil"),
+pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
 ),
 DispatcherInfo(
 F.autocontrast,
 kernels={
-datapoints.Image: F.autocontrast_image_tensor,
+datapoints.Image: F.autocontrast_image,
 datapoints.Video: F.autocontrast_video,
 },
-pil_kernel_info=PILKernelInfo(F.autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
+pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
 ),
 DispatcherInfo(
 F.adjust_sharpness,
 kernels={
-datapoints.Image: F.adjust_sharpness_image_tensor,
+datapoints.Image: F.adjust_sharpness_image,
 datapoints.Video: F.adjust_sharpness_video,
 },
-pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
+pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
 ),
 DispatcherInfo(
 F.erase,
 kernels={
-datapoints.Image: F.erase_image_tensor,
+datapoints.Image: F.erase_image,
 datapoints.Video: F.erase_video,
 },
-pil_kernel_info=PILKernelInfo(F.erase_image_pil),
+pil_kernel_info=PILKernelInfo(F._erase_image_pil),
 test_marks=[
 skip_dispatch_datapoint,
 ],
@@ -293,42 +293,42 @@ DISPATCHER_INFOS = [
 DispatcherInfo(
 F.adjust_contrast,
 kernels={
-datapoints.Image: F.adjust_contrast_image_tensor,
+datapoints.Image: F.adjust_contrast_image,
 datapoints.Video: F.adjust_contrast_video,
 },
-pil_kernel_info=PILKernelInfo(F.adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
+pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
 ),
 DispatcherInfo(
 F.adjust_gamma,
 kernels={
-datapoints.Image: F.adjust_gamma_image_tensor,
+datapoints.Image: F.adjust_gamma_image,
 datapoints.Video: F.adjust_gamma_video,
 },
-pil_kernel_info=PILKernelInfo(F.adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
+pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
 ),
 DispatcherInfo(
 F.adjust_hue,
 kernels={
-datapoints.Image: F.adjust_hue_image_tensor,
+datapoints.Image: F.adjust_hue_image,
 datapoints.Video: F.adjust_hue_video,
 },
-pil_kernel_info=PILKernelInfo(F.adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
+pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
 ),
 DispatcherInfo(
 F.adjust_saturation,
 kernels={
-datapoints.Image: F.adjust_saturation_image_tensor,
+datapoints.Image: F.adjust_saturation_image,
 datapoints.Video: F.adjust_saturation_video,
 },
-pil_kernel_info=PILKernelInfo(F.adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
+pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
 ),
 DispatcherInfo(
 F.five_crop,
 kernels={
-datapoints.Image: F.five_crop_image_tensor,
+datapoints.Image: F.five_crop_image,
 datapoints.Video: F.five_crop_video,
 },
-pil_kernel_info=PILKernelInfo(F.five_crop_image_pil),
+pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
 test_marks=[
 xfail_jit_python_scalar_arg("size"),
 *multi_crop_skips,
@@ -337,19 +337,19 @@ DISPATCHER_INFOS = [
 DispatcherInfo(
 F.ten_crop,
 kernels={
-datapoints.Image: F.ten_crop_image_tensor,
+datapoints.Image: F.ten_crop_image,
 datapoints.Video: F.ten_crop_video,
 },
 test_marks=[
 xfail_jit_python_scalar_arg("size"),
 *multi_crop_skips,
 ],
-pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil),
+pil_kernel_info=PILKernelInfo(F._ten_crop_image_pil),
 ),
 DispatcherInfo(
 F.normalize,
 kernels={
-datapoints.Image: F.normalize_image_tensor,
+datapoints.Image: F.normalize_image,
 datapoints.Video: F.normalize_video,
 },
 test_marks=[
...
@@ -122,12 +122,12 @@ def pil_reference_wrapper(pil_kernel):
 f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}"
 )
-input_pil = F.to_image_pil(input_tensor)
+input_pil = F.to_pil_image(input_tensor)
 output_pil = pil_kernel(input_pil, *other_args, **kwargs)
 if not isinstance(output_pil, PIL.Image.Image):
 return output_pil
-output_tensor = F.to_image_tensor(output_pil)
+output_tensor = F.to_image(output_pil)
 # 2D mask shenanigans
 if output_tensor.ndim == 2 and input_tensor.ndim == 3:
@@ -331,10 +331,10 @@ def reference_inputs_crop_bounding_boxes():
 KERNEL_INFOS.extend(
 [
 KernelInfo(
-F.crop_image_tensor,
+F.crop_image,
 kernel_name="crop_image_tensor",
 sample_inputs_fn=sample_inputs_crop_image_tensor,
-reference_fn=pil_reference_wrapper(F.crop_image_pil),
+reference_fn=pil_reference_wrapper(F._crop_image_pil),
 reference_inputs_fn=reference_inputs_crop_image_tensor,
 float32_vs_uint8=True,
 ),
@@ -347,7 +347,7 @@ KERNEL_INFOS.extend(
 KernelInfo(
 F.crop_mask,
 sample_inputs_fn=sample_inputs_crop_mask,
-reference_fn=pil_reference_wrapper(F.crop_image_pil),
+reference_fn=pil_reference_wrapper(F._crop_image_pil),
 reference_inputs_fn=reference_inputs_crop_mask,
 float32_vs_uint8=True,
 ),
@@ -373,7 +373,7 @@ def reference_resized_crop_image_tensor(*args, **kwargs):
 F.InterpolationMode.BICUBIC,
 }:
 raise pytest.UsageError("Anti-aliasing is always active in PIL")
-return F.resized_crop_image_pil(*args, **kwargs)
+return F._resized_crop_image_pil(*args, **kwargs)
 def reference_inputs_resized_crop_image_tensor():
@@ -417,7 +417,7 @@ def sample_inputs_resized_crop_video():
 KERNEL_INFOS.extend(
 [
 KernelInfo(
-F.resized_crop_image_tensor,
+F.resized_crop_image,
 sample_inputs_fn=sample_inputs_resized_crop_image_tensor,
 reference_fn=reference_resized_crop_image_tensor,
 reference_inputs_fn=reference_inputs_resized_crop_image_tensor,
@@ -570,9 +570,9 @@ def pad_xfail_jit_fill_condition(args_kwargs):
 KERNEL_INFOS.extend(
 [
 KernelInfo(
-F.pad_image_tensor,
+F.pad_image,
 sample_inputs_fn=sample_inputs_pad_image_tensor,
-reference_fn=pil_reference_wrapper(F.pad_image_pil),
+reference_fn=pil_reference_wrapper(F._pad_image_pil),
 reference_inputs_fn=reference_inputs_pad_image_tensor,
 float32_vs_uint8=float32_vs_uint8_fill_adapter,
 closeness_kwargs=float32_vs_uint8_pixel_difference(),
@@ -595,7 +595,7 @@ KERNEL_INFOS.extend(
 KernelInfo(
 F.pad_mask,
 sample_inputs_fn=sample_inputs_pad_mask,
-reference_fn=pil_reference_wrapper(F.pad_image_pil),
+reference_fn=pil_reference_wrapper(F._pad_image_pil),
 reference_inputs_fn=reference_inputs_pad_mask,
 float32_vs_uint8=float32_vs_uint8_fill_adapter,
 ),
@@ -690,9 +690,9 @@ def sample_inputs_perspective_video():
 KERNEL_INFOS.extend(
 [
 KernelInfo(
-F.perspective_image_tensor,
+F.perspective_image,
 sample_inputs_fn=sample_inputs_perspective_image_tensor,
-reference_fn=pil_reference_wrapper(F.perspective_image_pil),
+reference_fn=pil_reference_wrapper(F._perspective_image_pil),
 reference_inputs_fn=reference_inputs_perspective_image_tensor,
 float32_vs_uint8=float32_vs_uint8_fill_adapter,
 closeness_kwargs={
@@ -715,7 +715,7 @@ KERNEL_INFOS.extend(
 KernelInfo(
 F.perspective_mask,
 sample_inputs_fn=sample_inputs_perspective_mask,
-reference_fn=pil_reference_wrapper(F.perspective_image_pil),
+reference_fn=pil_reference_wrapper(F._perspective_image_pil),
 reference_inputs_fn=reference_inputs_perspective_mask,
 float32_vs_uint8=True,
 closeness_kwargs={
@@ -786,7 +786,7 @@ def sample_inputs_elastic_video():
 KERNEL_INFOS.extend(
 [
 KernelInfo(
-F.elastic_image_tensor,
+F.elastic_image,
 sample_inputs_fn=sample_inputs_elastic_image_tensor,
reference_inputs_fn=reference_inputs_elastic_image_tensor, reference_inputs_fn=reference_inputs_elastic_image_tensor,
float32_vs_uint8=float32_vs_uint8_fill_adapter, float32_vs_uint8=float32_vs_uint8_fill_adapter,
...@@ -870,9 +870,9 @@ def sample_inputs_center_crop_video(): ...@@ -870,9 +870,9 @@ def sample_inputs_center_crop_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.center_crop_image_tensor, F.center_crop_image,
sample_inputs_fn=sample_inputs_center_crop_image_tensor, sample_inputs_fn=sample_inputs_center_crop_image_tensor,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil), reference_fn=pil_reference_wrapper(F._center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_image_tensor, reference_inputs_fn=reference_inputs_center_crop_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
test_marks=[ test_marks=[
...@@ -889,7 +889,7 @@ KERNEL_INFOS.extend( ...@@ -889,7 +889,7 @@ KERNEL_INFOS.extend(
KernelInfo( KernelInfo(
F.center_crop_mask, F.center_crop_mask,
sample_inputs_fn=sample_inputs_center_crop_mask, sample_inputs_fn=sample_inputs_center_crop_mask,
reference_fn=pil_reference_wrapper(F.center_crop_image_pil), reference_fn=pil_reference_wrapper(F._center_crop_image_pil),
reference_inputs_fn=reference_inputs_center_crop_mask, reference_inputs_fn=reference_inputs_center_crop_mask,
float32_vs_uint8=True, float32_vs_uint8=True,
test_marks=[ test_marks=[
...@@ -924,7 +924,7 @@ def sample_inputs_gaussian_blur_video(): ...@@ -924,7 +924,7 @@ def sample_inputs_gaussian_blur_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.gaussian_blur_image_tensor, F.gaussian_blur_image,
sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor, sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
closeness_kwargs=cuda_vs_cpu_pixel_difference(), closeness_kwargs=cuda_vs_cpu_pixel_difference(),
test_marks=[ test_marks=[
...@@ -1010,10 +1010,10 @@ def sample_inputs_equalize_video(): ...@@ -1010,10 +1010,10 @@ def sample_inputs_equalize_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.equalize_image_tensor, F.equalize_image,
kernel_name="equalize_image_tensor", kernel_name="equalize_image_tensor",
sample_inputs_fn=sample_inputs_equalize_image_tensor, sample_inputs_fn=sample_inputs_equalize_image_tensor,
reference_fn=pil_reference_wrapper(F.equalize_image_pil), reference_fn=pil_reference_wrapper(F._equalize_image_pil),
float32_vs_uint8=True, float32_vs_uint8=True,
reference_inputs_fn=reference_inputs_equalize_image_tensor, reference_inputs_fn=reference_inputs_equalize_image_tensor,
), ),
...@@ -1043,10 +1043,10 @@ def sample_inputs_invert_video(): ...@@ -1043,10 +1043,10 @@ def sample_inputs_invert_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.invert_image_tensor, F.invert_image,
kernel_name="invert_image_tensor", kernel_name="invert_image_tensor",
sample_inputs_fn=sample_inputs_invert_image_tensor, sample_inputs_fn=sample_inputs_invert_image_tensor,
reference_fn=pil_reference_wrapper(F.invert_image_pil), reference_fn=pil_reference_wrapper(F._invert_image_pil),
reference_inputs_fn=reference_inputs_invert_image_tensor, reference_inputs_fn=reference_inputs_invert_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
), ),
...@@ -1082,10 +1082,10 @@ def sample_inputs_posterize_video(): ...@@ -1082,10 +1082,10 @@ def sample_inputs_posterize_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.posterize_image_tensor, F.posterize_image,
kernel_name="posterize_image_tensor", kernel_name="posterize_image_tensor",
sample_inputs_fn=sample_inputs_posterize_image_tensor, sample_inputs_fn=sample_inputs_posterize_image_tensor,
reference_fn=pil_reference_wrapper(F.posterize_image_pil), reference_fn=pil_reference_wrapper(F._posterize_image_pil),
reference_inputs_fn=reference_inputs_posterize_image_tensor, reference_inputs_fn=reference_inputs_posterize_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs=float32_vs_uint8_pixel_difference(), closeness_kwargs=float32_vs_uint8_pixel_difference(),
...@@ -1127,10 +1127,10 @@ def sample_inputs_solarize_video(): ...@@ -1127,10 +1127,10 @@ def sample_inputs_solarize_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.solarize_image_tensor, F.solarize_image,
kernel_name="solarize_image_tensor", kernel_name="solarize_image_tensor",
sample_inputs_fn=sample_inputs_solarize_image_tensor, sample_inputs_fn=sample_inputs_solarize_image_tensor,
reference_fn=pil_reference_wrapper(F.solarize_image_pil), reference_fn=pil_reference_wrapper(F._solarize_image_pil),
reference_inputs_fn=reference_inputs_solarize_image_tensor, reference_inputs_fn=reference_inputs_solarize_image_tensor,
float32_vs_uint8=uint8_to_float32_threshold_adapter, float32_vs_uint8=uint8_to_float32_threshold_adapter,
closeness_kwargs=float32_vs_uint8_pixel_difference(), closeness_kwargs=float32_vs_uint8_pixel_difference(),
...@@ -1161,10 +1161,10 @@ def sample_inputs_autocontrast_video(): ...@@ -1161,10 +1161,10 @@ def sample_inputs_autocontrast_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.autocontrast_image_tensor, F.autocontrast_image,
kernel_name="autocontrast_image_tensor", kernel_name="autocontrast_image_tensor",
sample_inputs_fn=sample_inputs_autocontrast_image_tensor, sample_inputs_fn=sample_inputs_autocontrast_image_tensor,
reference_fn=pil_reference_wrapper(F.autocontrast_image_pil), reference_fn=pil_reference_wrapper(F._autocontrast_image_pil),
reference_inputs_fn=reference_inputs_autocontrast_image_tensor, reference_inputs_fn=reference_inputs_autocontrast_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs={ closeness_kwargs={
...@@ -1206,10 +1206,10 @@ def sample_inputs_adjust_sharpness_video(): ...@@ -1206,10 +1206,10 @@ def sample_inputs_adjust_sharpness_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.adjust_sharpness_image_tensor, F.adjust_sharpness_image,
kernel_name="adjust_sharpness_image_tensor", kernel_name="adjust_sharpness_image_tensor",
sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor, sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil), reference_fn=pil_reference_wrapper(F._adjust_sharpness_image_pil),
reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor, reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs=float32_vs_uint8_pixel_difference(2), closeness_kwargs=float32_vs_uint8_pixel_difference(2),
...@@ -1241,7 +1241,7 @@ def sample_inputs_erase_video(): ...@@ -1241,7 +1241,7 @@ def sample_inputs_erase_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.erase_image_tensor, F.erase_image,
kernel_name="erase_image_tensor", kernel_name="erase_image_tensor",
sample_inputs_fn=sample_inputs_erase_image_tensor, sample_inputs_fn=sample_inputs_erase_image_tensor,
), ),
...@@ -1276,10 +1276,10 @@ def sample_inputs_adjust_contrast_video(): ...@@ -1276,10 +1276,10 @@ def sample_inputs_adjust_contrast_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.adjust_contrast_image_tensor, F.adjust_contrast_image,
kernel_name="adjust_contrast_image_tensor", kernel_name="adjust_contrast_image_tensor",
sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor, sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil), reference_fn=pil_reference_wrapper(F._adjust_contrast_image_pil),
reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor, reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs={ closeness_kwargs={
...@@ -1329,10 +1329,10 @@ def sample_inputs_adjust_gamma_video(): ...@@ -1329,10 +1329,10 @@ def sample_inputs_adjust_gamma_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.adjust_gamma_image_tensor, F.adjust_gamma_image,
kernel_name="adjust_gamma_image_tensor", kernel_name="adjust_gamma_image_tensor",
sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor, sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil), reference_fn=pil_reference_wrapper(F._adjust_gamma_image_pil),
reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor, reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs={ closeness_kwargs={
...@@ -1372,10 +1372,10 @@ def sample_inputs_adjust_hue_video(): ...@@ -1372,10 +1372,10 @@ def sample_inputs_adjust_hue_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.adjust_hue_image_tensor, F.adjust_hue_image,
kernel_name="adjust_hue_image_tensor", kernel_name="adjust_hue_image_tensor",
sample_inputs_fn=sample_inputs_adjust_hue_image_tensor, sample_inputs_fn=sample_inputs_adjust_hue_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil), reference_fn=pil_reference_wrapper(F._adjust_hue_image_pil),
reference_inputs_fn=reference_inputs_adjust_hue_image_tensor, reference_inputs_fn=reference_inputs_adjust_hue_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs={ closeness_kwargs={
...@@ -1414,10 +1414,10 @@ def sample_inputs_adjust_saturation_video(): ...@@ -1414,10 +1414,10 @@ def sample_inputs_adjust_saturation_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.adjust_saturation_image_tensor, F.adjust_saturation_image,
kernel_name="adjust_saturation_image_tensor", kernel_name="adjust_saturation_image_tensor",
sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor, sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil), reference_fn=pil_reference_wrapper(F._adjust_saturation_image_pil),
reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor, reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor,
float32_vs_uint8=True, float32_vs_uint8=True,
closeness_kwargs={ closeness_kwargs={
...@@ -1517,8 +1517,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel): ...@@ -1517,8 +1517,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel):
def wrapper(input_tensor, *other_args, **kwargs): def wrapper(input_tensor, *other_args, **kwargs):
output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs) output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs)
return type(output)( return type(output)(
F.to_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype, scale=True) F.to_dtype_image(F.to_image(output_pil), dtype=input_tensor.dtype, scale=True) for output_pil in output
for output_pil in output
) )
return wrapper return wrapper
...@@ -1532,9 +1531,9 @@ _common_five_ten_crop_marks = [ ...@@ -1532,9 +1531,9 @@ _common_five_ten_crop_marks = [
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.five_crop_image_tensor, F.five_crop_image,
sample_inputs_fn=sample_inputs_five_crop_image_tensor, sample_inputs_fn=sample_inputs_five_crop_image_tensor,
reference_fn=multi_crop_pil_reference_wrapper(F.five_crop_image_pil), reference_fn=multi_crop_pil_reference_wrapper(F._five_crop_image_pil),
reference_inputs_fn=reference_inputs_five_crop_image_tensor, reference_inputs_fn=reference_inputs_five_crop_image_tensor,
test_marks=_common_five_ten_crop_marks, test_marks=_common_five_ten_crop_marks,
), ),
...@@ -1544,9 +1543,9 @@ KERNEL_INFOS.extend( ...@@ -1544,9 +1543,9 @@ KERNEL_INFOS.extend(
test_marks=_common_five_ten_crop_marks, test_marks=_common_five_ten_crop_marks,
), ),
KernelInfo( KernelInfo(
F.ten_crop_image_tensor, F.ten_crop_image,
sample_inputs_fn=sample_inputs_ten_crop_image_tensor, sample_inputs_fn=sample_inputs_ten_crop_image_tensor,
reference_fn=multi_crop_pil_reference_wrapper(F.ten_crop_image_pil), reference_fn=multi_crop_pil_reference_wrapper(F._ten_crop_image_pil),
reference_inputs_fn=reference_inputs_ten_crop_image_tensor, reference_inputs_fn=reference_inputs_ten_crop_image_tensor,
test_marks=_common_five_ten_crop_marks, test_marks=_common_five_ten_crop_marks,
), ),
...@@ -1600,7 +1599,7 @@ def sample_inputs_normalize_video(): ...@@ -1600,7 +1599,7 @@ def sample_inputs_normalize_video():
KERNEL_INFOS.extend( KERNEL_INFOS.extend(
[ [
KernelInfo( KernelInfo(
F.normalize_image_tensor, F.normalize_image,
kernel_name="normalize_image_tensor", kernel_name="normalize_image_tensor",
sample_inputs_fn=sample_inputs_normalize_image_tensor, sample_inputs_fn=sample_inputs_normalize_image_tensor,
reference_fn=reference_normalize_image_tensor, reference_fn=reference_normalize_image_tensor,
......
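The reference wrappers in this test file round-trip tensors through PIL using the renamed conversion helpers. A minimal, hedged sketch of that round trip outside the test harness:

import torch
from torchvision.transforms.v2 import functional as F

# tensor -> PIL -> datapoints.Image, mirroring pil_reference_wrapper above.
t = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
pil = F.to_pil_image(t)   # renamed from to_image_pil
back = F.to_image(pil)    # renamed from to_image_tensor
assert tuple(back.shape) == tuple(t.shape)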
...@@ -112,7 +112,7 @@ class SimpleCopyPaste(Transform): ...@@ -112,7 +112,7 @@ class SimpleCopyPaste(Transform):
if isinstance(obj, datapoints.Image) or is_simple_tensor(obj): if isinstance(obj, datapoints.Image) or is_simple_tensor(obj):
images.append(obj) images.append(obj)
elif isinstance(obj, PIL.Image.Image): elif isinstance(obj, PIL.Image.Image):
images.append(F.to_image_tensor(obj)) images.append(F.to_image(obj))
elif isinstance(obj, datapoints.BoundingBoxes): elif isinstance(obj, datapoints.BoundingBoxes):
bboxes.append(obj) bboxes.append(obj)
elif isinstance(obj, datapoints.Mask): elif isinstance(obj, datapoints.Mask):
...@@ -144,7 +144,7 @@ class SimpleCopyPaste(Transform): ...@@ -144,7 +144,7 @@ class SimpleCopyPaste(Transform):
flat_sample[i] = datapoints.wrap(output_images[c0], like=obj) flat_sample[i] = datapoints.wrap(output_images[c0], like=obj)
c0 += 1 c0 += 1
elif isinstance(obj, PIL.Image.Image): elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0]) flat_sample[i] = F.to_pil_image(output_images[c0])
c0 += 1 c0 += 1
elif is_simple_tensor(obj): elif is_simple_tensor(obj):
flat_sample[i] = output_images[c0] flat_sample[i] = output_images[c0]
......
...@@ -52,7 +52,7 @@ from ._misc import ( ...@@ -52,7 +52,7 @@ from ._misc import (
ToDtype, ToDtype,
) )
from ._temporal import UniformTemporalSubsample from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage from ._type_conversion import PILToTensor, ToImage, ToPILImage
from ._deprecated import ToTensor # usort: skip from ._deprecated import ToTensor # usort: skip
......
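With `ToImageTensor`/`ToImagePIL` renamed to `ToImage`/`ToPILImage`, the import list above shrinks accordingly. A hedged sketch of a v2 pipeline using the new names (illustrative only, not part of this diff):

import torch
import torchvision.transforms.v2 as transforms

# Typical v2 preprocessing with the renamed transforms.
pipeline = transforms.Compose(
    [
        transforms.ToImage(),                        # formerly ToImageTensor
        transforms.ConvertImageDtype(torch.float32),
        transforms.RandomHorizontalFlip(),
    ]
)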
...@@ -622,6 +622,6 @@ class AugMix(_AutoAugmentBase): ...@@ -622,6 +622,6 @@ class AugMix(_AutoAugmentBase):
if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)): if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
mix = datapoints.wrap(mix, like=orig_image_or_video) mix = datapoints.wrap(mix, like=orig_image_or_video)
elif isinstance(orig_image_or_video, PIL.Image.Image): elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_image_pil(mix) mix = F.to_pil_image(mix)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix) return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
...@@ -26,7 +26,7 @@ class PILToTensor(Transform): ...@@ -26,7 +26,7 @@ class PILToTensor(Transform):
return F.pil_to_tensor(inpt) return F.pil_to_tensor(inpt)
class ToImageTensor(Transform): class ToImage(Transform):
"""[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image` """[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image`
; this does not scale values. ; this does not scale values.
...@@ -40,10 +40,10 @@ class ToImageTensor(Transform): ...@@ -40,10 +40,10 @@ class ToImageTensor(Transform):
def _transform( def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> datapoints.Image: ) -> datapoints.Image:
return F.to_image_tensor(inpt) return F.to_image(inpt)
class ToImagePIL(Transform): class ToPILImage(Transform):
"""[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values. """[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values.
.. v2betastatus:: ToImagePIL transform .. v2betastatus:: ToImagePIL transform
...@@ -74,9 +74,4 @@ class ToImagePIL(Transform): ...@@ -74,9 +74,4 @@ class ToImagePIL(Transform):
def _transform( def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> PIL.Image.Image: ) -> PIL.Image.Image:
return F.to_image_pil(inpt, mode=self.mode) return F.to_pil_image(inpt, mode=self.mode)
# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
ToPILImage = ToImagePIL
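Because the class itself is now called `ToPILImage`, the alias above is dropped. A hedged usage sketch with the direct name:

import torch
import torchvision.transforms.v2 as transforms

# ToPILImage is the class name itself, no longer an alias of ToImagePIL.
to_pil = transforms.ToPILImage()
pil_img = to_pil(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8))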
...@@ -5,173 +5,173 @@ from ._utils import is_simple_tensor, register_kernel # usort: skip ...@@ -5,173 +5,173 @@ from ._utils import is_simple_tensor, register_kernel # usort: skip
from ._meta import ( from ._meta import (
clamp_bounding_boxes, clamp_bounding_boxes,
convert_format_bounding_boxes, convert_format_bounding_boxes,
get_dimensions_image_tensor, get_dimensions_image,
get_dimensions_image_pil, _get_dimensions_image_pil,
get_dimensions_video, get_dimensions_video,
get_dimensions, get_dimensions,
get_num_frames_video, get_num_frames_video,
get_num_frames, get_num_frames,
get_image_num_channels, get_image_num_channels,
get_num_channels_image_tensor, get_num_channels_image,
get_num_channels_image_pil, _get_num_channels_image_pil,
get_num_channels_video, get_num_channels_video,
get_num_channels, get_num_channels,
get_size_bounding_boxes, get_size_bounding_boxes,
get_size_image_tensor, get_size_image,
get_size_image_pil, _get_size_image_pil,
get_size_mask, get_size_mask,
get_size_video, get_size_video,
get_size, get_size,
) # usort: skip ) # usort: skip
from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video from ._augment import _erase_image_pil, erase, erase_image, erase_video
from ._color import ( from ._color import (
_adjust_brightness_image_pil,
_adjust_contrast_image_pil,
_adjust_gamma_image_pil,
_adjust_hue_image_pil,
_adjust_saturation_image_pil,
_adjust_sharpness_image_pil,
_autocontrast_image_pil,
_equalize_image_pil,
_invert_image_pil,
_permute_channels_image_pil,
_posterize_image_pil,
_rgb_to_grayscale_image_pil,
_solarize_image_pil,
adjust_brightness, adjust_brightness,
adjust_brightness_image_pil, adjust_brightness_image,
adjust_brightness_image_tensor,
adjust_brightness_video, adjust_brightness_video,
adjust_contrast, adjust_contrast,
adjust_contrast_image_pil, adjust_contrast_image,
adjust_contrast_image_tensor,
adjust_contrast_video, adjust_contrast_video,
adjust_gamma, adjust_gamma,
adjust_gamma_image_pil, adjust_gamma_image,
adjust_gamma_image_tensor,
adjust_gamma_video, adjust_gamma_video,
adjust_hue, adjust_hue,
adjust_hue_image_pil, adjust_hue_image,
adjust_hue_image_tensor,
adjust_hue_video, adjust_hue_video,
adjust_saturation, adjust_saturation,
adjust_saturation_image_pil, adjust_saturation_image,
adjust_saturation_image_tensor,
adjust_saturation_video, adjust_saturation_video,
adjust_sharpness, adjust_sharpness,
adjust_sharpness_image_pil, adjust_sharpness_image,
adjust_sharpness_image_tensor,
adjust_sharpness_video, adjust_sharpness_video,
autocontrast, autocontrast,
autocontrast_image_pil, autocontrast_image,
autocontrast_image_tensor,
autocontrast_video, autocontrast_video,
equalize, equalize,
equalize_image_pil, equalize_image,
equalize_image_tensor,
equalize_video, equalize_video,
invert, invert,
invert_image_pil, invert_image,
invert_image_tensor,
invert_video, invert_video,
permute_channels, permute_channels,
permute_channels_image_pil, permute_channels_image,
permute_channels_image_tensor,
permute_channels_video, permute_channels_video,
posterize, posterize,
posterize_image_pil, posterize_image,
posterize_image_tensor,
posterize_video, posterize_video,
rgb_to_grayscale, rgb_to_grayscale,
rgb_to_grayscale_image_pil, rgb_to_grayscale_image,
rgb_to_grayscale_image_tensor,
solarize, solarize,
solarize_image_pil, solarize_image,
solarize_image_tensor,
solarize_video, solarize_video,
to_grayscale, to_grayscale,
) )
from ._geometry import ( from ._geometry import (
_affine_image_pil,
_center_crop_image_pil,
_crop_image_pil,
_elastic_image_pil,
_five_crop_image_pil,
_horizontal_flip_image_pil,
_pad_image_pil,
_perspective_image_pil,
_resize_image_pil,
_resized_crop_image_pil,
_rotate_image_pil,
_ten_crop_image_pil,
_vertical_flip_image_pil,
affine, affine,
affine_bounding_boxes, affine_bounding_boxes,
affine_image_pil, affine_image,
affine_image_tensor,
affine_mask, affine_mask,
affine_video, affine_video,
center_crop, center_crop,
center_crop_bounding_boxes, center_crop_bounding_boxes,
center_crop_image_pil, center_crop_image,
center_crop_image_tensor,
center_crop_mask, center_crop_mask,
center_crop_video, center_crop_video,
crop, crop,
crop_bounding_boxes, crop_bounding_boxes,
crop_image_pil, crop_image,
crop_image_tensor,
crop_mask, crop_mask,
crop_video, crop_video,
elastic, elastic,
elastic_bounding_boxes, elastic_bounding_boxes,
elastic_image_pil, elastic_image,
elastic_image_tensor,
elastic_mask, elastic_mask,
elastic_transform, elastic_transform,
elastic_video, elastic_video,
five_crop, five_crop,
five_crop_image_pil, five_crop_image,
five_crop_image_tensor,
five_crop_video, five_crop_video,
hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file
horizontal_flip, horizontal_flip,
horizontal_flip_bounding_boxes, horizontal_flip_bounding_boxes,
horizontal_flip_image_pil, horizontal_flip_image,
horizontal_flip_image_tensor,
horizontal_flip_mask, horizontal_flip_mask,
horizontal_flip_video, horizontal_flip_video,
pad, pad,
pad_bounding_boxes, pad_bounding_boxes,
pad_image_pil, pad_image,
pad_image_tensor,
pad_mask, pad_mask,
pad_video, pad_video,
perspective, perspective,
perspective_bounding_boxes, perspective_bounding_boxes,
perspective_image_pil, perspective_image,
perspective_image_tensor,
perspective_mask, perspective_mask,
perspective_video, perspective_video,
resize, resize,
resize_bounding_boxes, resize_bounding_boxes,
resize_image_pil, resize_image,
resize_image_tensor,
resize_mask, resize_mask,
resize_video, resize_video,
resized_crop, resized_crop,
resized_crop_bounding_boxes, resized_crop_bounding_boxes,
resized_crop_image_pil, resized_crop_image,
resized_crop_image_tensor,
resized_crop_mask, resized_crop_mask,
resized_crop_video, resized_crop_video,
rotate, rotate,
rotate_bounding_boxes, rotate_bounding_boxes,
rotate_image_pil, rotate_image,
rotate_image_tensor,
rotate_mask, rotate_mask,
rotate_video, rotate_video,
ten_crop, ten_crop,
ten_crop_image_pil, ten_crop_image,
ten_crop_image_tensor,
ten_crop_video, ten_crop_video,
vertical_flip, vertical_flip,
vertical_flip_bounding_boxes, vertical_flip_bounding_boxes,
vertical_flip_image_pil, vertical_flip_image,
vertical_flip_image_tensor,
vertical_flip_mask, vertical_flip_mask,
vertical_flip_video, vertical_flip_video,
vflip, vflip,
) )
from ._misc import ( from ._misc import (
_gaussian_blur_image_pil,
convert_image_dtype, convert_image_dtype,
gaussian_blur, gaussian_blur,
gaussian_blur_image_pil, gaussian_blur_image,
gaussian_blur_image_tensor,
gaussian_blur_video, gaussian_blur_video,
normalize, normalize,
normalize_image_tensor, normalize_image,
normalize_video, normalize_video,
to_dtype, to_dtype,
to_dtype_image_tensor, to_dtype_image,
to_dtype_video, to_dtype_video,
) )
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image from ._type_conversion import pil_to_tensor, to_image, to_pil_image
from ._deprecated import get_image_size, to_tensor # usort: skip from ._deprecated import get_image_size, to_tensor # usort: skip
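After this import reshuffle, only the tensor kernels (without the `_tensor` suffix) and the high-level dispatchers remain public; the `_*_image_pil` kernels are private implementation details. A hedged sketch of the intended call pattern:

import PIL.Image
from torchvision.transforms.v2 import functional as F

# Use the public dispatcher regardless of input type; it routes PIL inputs to the
# now-private PIL kernel (e.g. F._horizontal_flip_image_pil) internally.
pil_img = PIL.Image.new("RGB", (32, 32))
flipped = F.horizontal_flip(pil_img)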
...@@ -18,7 +18,7 @@ def erase( ...@@ -18,7 +18,7 @@ def erase(
inplace: bool = False, inplace: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if torch.jit.is_scripting(): if torch.jit.is_scripting():
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
_log_api_usage_once(erase) _log_api_usage_once(erase)
...@@ -28,7 +28,7 @@ def erase( ...@@ -28,7 +28,7 @@ def erase(
@_register_kernel_internal(erase, torch.Tensor) @_register_kernel_internal(erase, torch.Tensor)
@_register_kernel_internal(erase, datapoints.Image) @_register_kernel_internal(erase, datapoints.Image)
def erase_image_tensor( def erase_image(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
if not inplace: if not inplace:
...@@ -39,11 +39,11 @@ def erase_image_tensor( ...@@ -39,11 +39,11 @@ def erase_image_tensor(
@_register_kernel_internal(erase, PIL.Image.Image) @_register_kernel_internal(erase, PIL.Image.Image)
def erase_image_pil( def _erase_image_pil(
image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image: ) -> PIL.Image.Image:
t_img = pil_to_tensor(image) t_img = pil_to_tensor(image)
output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace) output = erase_image(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return to_pil_image(output, mode=image.mode) return to_pil_image(output, mode=image.mode)
...@@ -51,4 +51,4 @@ def erase_image_pil( ...@@ -51,4 +51,4 @@ def erase_image_pil(
def erase_video( def erase_video(
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
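As a hedged illustration of the renamed kernel (not part of this diff), the `erase` dispatcher and `erase_image` can be exercised like this:

import torch
from torchvision.transforms.v2 import functional as F

# Zero out a 4x4 patch starting at row 2, column 2; erase_video and the private
# _erase_image_pil kernel both delegate to the same erase_image logic.
img = torch.rand(3, 16, 16)
patch = torch.zeros(3, 4, 4)
out = F.erase(img, i=2, j=2, h=4, w=4, v=patch)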