Unverified Commit 37081ee6 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Revamp transforms doc (#7859)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 2c44ebae
.. _datapoints:
Datapoints
==========
......
.. _datasets:
Datasets
========
......
This diff is collapsed.
.. _transforms_gallery:
V2 transforms
-------------
......@@ -235,7 +235,8 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# Alternatively, you can use the :func:`~torchvision.datapoints.set_return_type`
# as a global config setting for the whole program, or as a context manager:
# as a global config setting for the whole program, or as a context manager
# (read its docs to learn more about caveats):
with datapoints.set_return_type("datapoint"):
new_bboxes = bboxes + 3
......@@ -274,13 +275,13 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# ^^^^^^^^^^
#
# There are a few exceptions to this "unwrapping" rule:
# :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
# :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
# the datapoint type.
#
# 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
# :meth:`torch.Tensor.detach` and :meth:`~torch.Tensor.requires_grad_` retain
# the datapoint type.
# 2. Inplace operations on datapoints like ``.add_()`` preserve they type. However,
# the **returned** value of inplace operations will be unwrapped into a pure
# tensor:
# Inplace operations on datapoints like ``obj.add_()`` will preserve the type of
# ``obj``. However, the **returned** value of inplace operations will be a pure
# tensor:
image = datapoints.Image([[[0, 1], [1, 0]]])
......
......@@ -14,7 +14,7 @@ from torchvision import datapoints
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_bounding_box_format
from torchvision.transforms.v2.utils import is_pure_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
......@@ -390,7 +390,7 @@ class TestDispatchers:
assert isinstance(output, type(datapoint))
if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes:
if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
assert output.format == datapoint.format
@pytest.mark.parametrize(
......@@ -445,7 +445,7 @@ class TestDispatchers:
[
info
for info in DISPATCHER_INFOS
if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_format_bounding_boxes
if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
)
......@@ -532,19 +532,19 @@ class TestConvertFormatBoundingBoxes:
)
def test_missing_new_format(self, inpt, old_format):
with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
F.convert_format_bounding_boxes(inpt, old_format)
F.convert_bounding_box_format(inpt, old_format)
def test_pure_tensor_insufficient_metadata(self):
pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
F.convert_format_bounding_boxes(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
F.convert_bounding_box_format(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
def test_datapoint_explicit_metadata(self):
datapoint = next(make_multiple_bounding_boxes())
with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
F.convert_format_bounding_boxes(
F.convert_bounding_box_format(
datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
)
......@@ -611,7 +611,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
]
in_boxes = torch.tensor(in_boxes, device=device)
if format != datapoints.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
expected_bboxes = clamp_bounding_boxes(
datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
......@@ -627,7 +627,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
)
if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
torch.testing.assert_close(output_canvas_size, canvas_size)
......@@ -681,12 +681,12 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
)
if format != datapoints.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)
if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_canvas_size, size)
......@@ -714,13 +714,13 @@ def test_correctness_pad_bounding_boxes(device, padding):
bbox = (
bbox.clone()
if format == datapoints.BoundingBoxFormat.XYXY
else convert_format_bounding_boxes(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
else convert_bounding_box_format(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
)
bbox[0::2] += pad_left
bbox[1::2] += pad_up
bbox = convert_format_bounding_boxes(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
bbox = convert_bounding_box_format(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
if bbox.dtype != dtype:
# Temporary cast to original dtype
# e.g. float32 -> int
......@@ -785,9 +785,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
]
)
bbox_xyxy = convert_format_bounding_boxes(
bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY
)
bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY)
points = np.array(
[
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
......@@ -808,7 +806,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
]
)
out_bbox = torch.from_numpy(out_bbox)
out_bbox = convert_format_bounding_boxes(
out_bbox = convert_bounding_box_format(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
)
return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)
......@@ -848,7 +846,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
def test_correctness_center_crop_bounding_boxes(device, output_size):
def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
dtype = bbox.dtype
bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
bbox = convert_bounding_box_format(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
if len(output_size_) == 1:
output_size_.append(output_size_[-1])
......@@ -862,7 +860,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
bbox[3].item(),
]
out_bbox = torch.tensor(out_bbox)
out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
out_bbox = convert_bounding_box_format(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size)
return out_bbox.to(dtype=dtype, device=bbox.device)
......
......@@ -342,7 +342,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
in_dtype = bbox.dtype
if not torch.is_floating_point(bbox):
bbox = bbox.float()
bbox_xyxy = F.convert_format_bounding_boxes(
bbox_xyxy = F.convert_bounding_box_format(
bbox.as_subclass(torch.Tensor),
old_format=format,
new_format=datapoints.BoundingBoxFormat.XYXY,
......@@ -366,7 +366,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
],
dtype=bbox_xyxy.dtype,
)
out_bbox = F.convert_format_bounding_boxes(
out_bbox = F.convert_bounding_box_format(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
)
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
......
......@@ -374,8 +374,8 @@ DISPATCHER_INFOS = [
],
),
DispatcherInfo(
F.convert_format_bounding_boxes,
kernels={datapoints.BoundingBoxes: F.convert_format_bounding_boxes},
F.convert_bounding_box_format,
kernels={datapoints.BoundingBoxes: F.convert_bounding_box_format},
test_marks=[
skip_dispatch_datapoint,
],
......
......@@ -190,7 +190,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
in_dtype = bbox.dtype
if not torch.is_floating_point(bbox):
bbox = bbox.float()
bbox_xyxy = F.convert_format_bounding_boxes(
bbox_xyxy = F.convert_bounding_box_format(
bbox.as_subclass(torch.Tensor),
old_format=format_,
new_format=datapoints.BoundingBoxFormat.XYXY,
......@@ -214,7 +214,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
],
dtype=bbox_xyxy.dtype,
)
out_bbox = F.convert_format_bounding_boxes(
out_bbox = F.convert_bounding_box_format(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
)
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
......@@ -227,30 +227,30 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
).reshape(bounding_boxes.shape)
def sample_inputs_convert_format_bounding_boxes():
def sample_inputs_convert_bounding_box_format():
formats = list(datapoints.BoundingBoxFormat)
for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)
def reference_convert_format_bounding_boxes(bounding_boxes, old_format, new_format):
def reference_convert_bounding_box_format(bounding_boxes, old_format, new_format):
return torchvision.ops.box_convert(
bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
).to(bounding_boxes.dtype)
def reference_inputs_convert_format_bounding_boxes():
for args_kwargs in sample_inputs_convert_format_bounding_boxes():
def reference_inputs_convert_bounding_box_format():
for args_kwargs in sample_inputs_convert_bounding_box_format():
if len(args_kwargs.args[0].shape) == 2:
yield args_kwargs
KERNEL_INFOS.append(
KernelInfo(
F.convert_format_bounding_boxes,
sample_inputs_fn=sample_inputs_convert_format_bounding_boxes,
reference_fn=reference_convert_format_bounding_boxes,
reference_inputs_fn=reference_inputs_convert_format_bounding_boxes,
F.convert_bounding_box_format,
sample_inputs_fn=sample_inputs_convert_bounding_box_format,
reference_fn=reference_convert_bounding_box_format,
reference_inputs_fn=reference_inputs_convert_bounding_box_format,
logs_usage=True,
closeness_kwargs={
(("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
......
......@@ -368,7 +368,7 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
target["image_id"] = image_id
if "boxes" in target_keys:
target["boxes"] = F.convert_format_bounding_boxes(
target["boxes"] = F.convert_bounding_box_format(
datapoints.BoundingBoxes(
batched_target["bbox"],
format=datapoints.BoundingBoxFormat.XYWH,
......@@ -489,7 +489,7 @@ def celeba_wrapper_factory(dataset, target_keys):
target,
target_types=dataset.target_type,
type_wrappers={
"bbox": lambda item: F.convert_format_bounding_boxes(
"bbox": lambda item: F.convert_bounding_box_format(
datapoints.BoundingBoxes(
item,
format=datapoints.BoundingBoxFormat.XYWH,
......@@ -636,7 +636,7 @@ def widerface_wrapper(dataset, target_keys):
target = {key: target[key] for key in target_keys}
if "bbox" in target_keys:
target["bbox"] = F.convert_format_bounding_boxes(
target["bbox"] = F.convert_bounding_box_format(
datapoints.BoundingBoxes(
target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
),
......
......@@ -22,6 +22,13 @@ def set_return_type(return_type: str):
``torchvision`` transforms or functionals, which will always return as
output the same type that was passed as input.
.. warning::
We recommend using :class:`~torchvision.transforms.v2.ToPureTensor` at
the end of your transform pipelines if you use
``set_return_type("dataptoint")``. This will avoid the
``__torch_function__`` overhead in the models ``forward()``.
Can be used as a global flag for the entire program:
.. code:: python
......
......@@ -80,7 +80,7 @@ class SimpleCopyPaste(Transform):
# There is a similar +1 in other reference implementations:
# https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
xyxy_boxes[:, 2:] += 1
boxes = F.convert_format_bounding_boxes(
boxes = F.convert_bounding_box_format(
xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
)
out_target["boxes"] = torch.cat([boxes, paste_boxes])
......@@ -89,7 +89,7 @@ class SimpleCopyPaste(Transform):
out_target["labels"] = torch.cat([labels, paste_labels])
# Check for degenerated boxes and remove them
boxes = F.convert_format_bounding_boxes(
boxes = F.convert_bounding_box_format(
out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY
)
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
......
......@@ -76,7 +76,7 @@ class FixedSizeCrop(Transform):
width=new_width,
)
bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size)
height_and_width = F.convert_format_bounding_boxes(
height_and_width = F.convert_bounding_box_format(
bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH
)[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1)
......
......@@ -10,13 +10,15 @@ from torchvision.transforms.v2 import Transform
class ToTensor(Transform):
"""[BETA] Convert a PIL Image or ndarray to tensor and scale the values accordingly.
"""[BETA] [DEPRECATED] Use ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`` instead.
Convert a PIL Image or ndarray to tensor and scale the values accordingly.
.. v2betastatus:: ToTensor transform
.. warning::
:class:`v2.ToTensor` is deprecated and will be removed in a future release.
Please use instead ``v2.Compose([transforms.ToImageTensor(), v2.ToDtype(torch.float32, scale=True)])``.
Please use instead ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])``.
This transform does not support torchscript.
......@@ -40,7 +42,7 @@ class ToTensor(Transform):
def __init__(self) -> None:
warnings.warn(
"The transform `ToTensor()` is deprecated and will be removed in a future release. "
"Instead, please use `v2.Compose([transforms.ToImageTensor(), v2.ToDtype(torch.float32, scale=True)])`."
"Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`."
)
super().__init__()
......
......@@ -1186,7 +1186,7 @@ class RandomIoUCrop(Transform):
continue
# check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_format_bounding_boxes(
xyxy_bboxes = F.convert_bounding_box_format(
bboxes.as_subclass(torch.Tensor),
bboxes.format,
datapoints.BoundingBoxFormat.XYXY,
......
......@@ -24,7 +24,7 @@ class ConvertBoundingBoxFormat(Transform):
self.format = format
def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
return F.convert_format_bounding_boxes(inpt, new_format=self.format) # type: ignore[return-value]
return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value]
class ClampBoundingBoxes(Transform):
......
......@@ -293,7 +293,9 @@ class ToDtype(Transform):
class ConvertImageDtype(Transform):
"""[BETA] Convert input image to the given ``dtype`` and scale the values accordingly.
"""[BETA] [DEPRECATED] Use ``v2.ToDtype(dtype, scale=True)`` instead.
Convert input image to the given ``dtype`` and scale the values accordingly.
.. v2betastatus:: ConvertImageDtype transform
......@@ -388,7 +390,7 @@ class SanitizeBoundingBoxes(Transform):
boxes = cast(
datapoints.BoundingBoxes,
F.convert_format_bounding_boxes(
F.convert_bounding_box_format(
boxes,
new_format=datapoints.BoundingBoxFormat.XYXY,
),
......
......@@ -4,7 +4,7 @@ from ._utils import is_pure_tensor, register_kernel # usort: skip
from ._meta import (
clamp_bounding_boxes,
convert_format_bounding_boxes,
convert_bounding_box_format,
get_dimensions_image,
_get_dimensions_image_pil,
get_dimensions_video,
......
......@@ -17,6 +17,7 @@ def erase(
v: torch.Tensor,
inplace: bool = False,
) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.RandomErase` for details."""
if torch.jit.is_scripting():
return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
......
......@@ -15,6 +15,7 @@ from ._utils import _get_kernel, _register_kernel_internal
def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.Grayscale` for details."""
if torch.jit.is_scripting():
return rgb_to_grayscale_image(inpt, num_output_channels=num_output_channels)
......@@ -69,6 +70,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
"""Adjust brightness."""
if torch.jit.is_scripting():
return adjust_brightness_image(inpt, brightness_factor=brightness_factor)
......@@ -106,6 +108,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
"""Adjust saturation."""
if torch.jit.is_scripting():
return adjust_saturation_image(inpt, saturation_factor=saturation_factor)
......@@ -144,6 +147,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.RandomAutocontrast`"""
if torch.jit.is_scripting():
return adjust_contrast_image(inpt, contrast_factor=contrast_factor)
......@@ -182,6 +186,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.RandomAdjustSharpness`"""
if torch.jit.is_scripting():
return adjust_sharpness_image(inpt, sharpness_factor=sharpness_factor)
......@@ -254,6 +259,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
"""Adjust hue"""
if torch.jit.is_scripting():
return adjust_hue_image(inpt, hue_factor=hue_factor)
......@@ -371,6 +377,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
"""Adjust gamma."""
if torch.jit.is_scripting():
return adjust_gamma_image(inpt, gamma=gamma, gain=gain)
......@@ -410,6 +417,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.RandomPosterize` for details."""
if torch.jit.is_scripting():
return posterize_image(inpt, bits=bits)
......@@ -443,6 +451,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.RandomSolarize` for details."""
if torch.jit.is_scripting():
return solarize_image(inpt, threshold=threshold)
......@@ -470,6 +479,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.RandomAutocontrast` for details."""
if torch.jit.is_scripting():
return autocontrast_image(inpt)
......@@ -519,6 +529,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
def equalize(inpt: torch.Tensor) -> torch.Tensor:
"""[BETA] See :class:`~torchvision.transforms.v2.RandomEqualize` for details."""
if torch.jit.is_scripting():
return equalize_image(inpt)
......@@ -608,6 +619,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
def invert(inpt: torch.Tensor) -> torch.Tensor:
"""[BETA] See :func:`~torchvision.transforms.v2.RandomInvert`."""
if torch.jit.is_scripting():
return invert_image(inpt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment