Commit 37081ee6 authored by Nicolas Hug, committed by GitHub

Revamp transforms doc (#7859)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 2c44ebae
.. _datapoints:
Datapoints
==========
......
.. _datasets:
Datasets
========
......
.. _transforms_gallery:
V2 transforms
-------------
@@ -235,7 +235,8 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# %%
# Alternatively, you can use the :func:`~torchvision.datapoints.set_return_type`
-# as a global config setting for the whole program, or as a context manager:
+# as a global config setting for the whole program, or as a context manager
+# (read its docs to learn more about caveats):

with datapoints.set_return_type("datapoint"):
    new_bboxes = bboxes + 3
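
A minimal sketch of the two usages the comment above describes (editor's illustration, not part of the commit; the box values are made up)::

    from torchvision import datapoints

    bboxes = datapoints.BoundingBoxes(
        [[10, 10, 20, 20]],
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(100, 100),
    )

    # As a global config setting: tensor ops now keep the datapoint type.
    datapoints.set_return_type("datapoint")
    assert isinstance(bboxes + 3, datapoints.BoundingBoxes)
    datapoints.set_return_type("tensor")  # restore the default behavior

    # As a context manager: only ops inside the block keep the datapoint type.
    with datapoints.set_return_type("datapoint"):
        assert isinstance(bboxes + 3, datapoints.BoundingBoxes)
    assert not isinstance(bboxes + 3, datapoints.BoundingBoxes)
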
@@ -274,13 +275,13 @@ assert isinstance(new_bboxes, datapoints.BoundingBoxes)
# ^^^^^^^^^^
#
# There are a few exceptions to this "unwrapping" rule:
+# :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
+# :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
+# the datapoint type.
#
-# 1. Operations like :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
-#    :meth:`torch.Tensor.detach` and :meth:`~torch.Tensor.requires_grad_` retain
-#    the datapoint type.
-# 2. Inplace operations on datapoints like ``.add_()`` preserve they type. However,
-#    the **returned** value of inplace operations will be unwrapped into a pure
-#    tensor:
+# Inplace operations on datapoints like ``obj.add_()`` will preserve the type of
+# ``obj``. However, the **returned** value of inplace operations will be a pure
+# tensor:

image = datapoints.Image([[[0, 1], [1, 0]]])
......
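
The gallery snippet that starts at ``image = datapoints.Image(...)`` is elided in this view; a short sketch of the point the new comment makes (editor's illustration, not part of the commit)::

    import torch
    from torchvision import datapoints

    image = datapoints.Image([[[0, 1], [1, 0]]])

    new_image = image.add_(1).mul_(2)
    assert isinstance(image, datapoints.Image)           # the operand keeps its datapoint type
    assert isinstance(new_image, torch.Tensor)
    assert not isinstance(new_image, datapoints.Image)   # the returned value is a pure tensor
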
@@ -14,7 +14,7 @@ from torchvision import datapoints
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
-from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
+from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_bounding_box_format
from torchvision.transforms.v2.utils import is_pure_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -390,7 +390,7 @@ class TestDispatchers:
        assert isinstance(output, type(datapoint))

-        if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_format_bounding_boxes:
+        if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
            assert output.format == datapoint.format

    @pytest.mark.parametrize(
@@ -445,7 +445,7 @@ class TestDispatchers:
        [
            info
            for info in DISPATCHER_INFOS
-            if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_format_bounding_boxes
+            if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
        ],
        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
    )
@@ -532,19 +532,19 @@ class TestConvertFormatBoundingBoxes:
    )
    def test_missing_new_format(self, inpt, old_format):
        with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
-            F.convert_format_bounding_boxes(inpt, old_format)
+            F.convert_bounding_box_format(inpt, old_format)

    def test_pure_tensor_insufficient_metadata(self):
        pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)
        with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
-            F.convert_format_bounding_boxes(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
+            F.convert_bounding_box_format(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)

    def test_datapoint_explicit_metadata(self):
        datapoint = next(make_multiple_bounding_boxes())
        with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
-            F.convert_format_bounding_boxes(
+            F.convert_bounding_box_format(
                datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
            )
@@ -611,7 +611,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
    ]
    in_boxes = torch.tensor(in_boxes, device=device)
    if format != datapoints.BoundingBoxFormat.XYXY:
-        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
+        in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    expected_bboxes = clamp_bounding_boxes(
        datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
@@ -627,7 +627,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
    )
    if format != datapoints.BoundingBoxFormat.XYXY:
-        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
+        output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
    torch.testing.assert_close(output_canvas_size, canvas_size)
@@ -681,12 +681,12 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
        in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
    )
    if format != datapoints.BoundingBoxFormat.XYXY:
-        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
+        in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

    output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)

    if format != datapoints.BoundingBoxFormat.XYXY:
-        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
+        output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes, expected_bboxes)
    torch.testing.assert_close(output_canvas_size, size)
@@ -714,13 +714,13 @@ def test_correctness_pad_bounding_boxes(device, padding):
        bbox = (
            bbox.clone()
            if format == datapoints.BoundingBoxFormat.XYXY
-            else convert_format_bounding_boxes(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
+            else convert_bounding_box_format(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
        )

        bbox[0::2] += pad_left
        bbox[1::2] += pad_up

-        bbox = convert_format_bounding_boxes(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
+        bbox = convert_bounding_box_format(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
        if bbox.dtype != dtype:
            # Temporary cast to original dtype
            # e.g. float32 -> int
@@ -785,9 +785,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
            ]
        )

-        bbox_xyxy = convert_format_bounding_boxes(
-            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY
-        )
+        bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY)
        points = np.array(
            [
                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -808,7 +806,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
            ]
        )
        out_bbox = torch.from_numpy(out_bbox)
-        out_bbox = convert_format_bounding_boxes(
+        out_bbox = convert_bounding_box_format(
            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
        )
        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)
@@ -848,7 +846,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
def test_correctness_center_crop_bounding_boxes(device, output_size):
    def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
        dtype = bbox.dtype
-        bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
+        bbox = convert_bounding_box_format(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)

        if len(output_size_) == 1:
            output_size_.append(output_size_[-1])
@@ -862,7 +860,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
            bbox[3].item(),
        ]
        out_bbox = torch.tensor(out_bbox)
-        out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
+        out_bbox = convert_bounding_box_format(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
        out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size)
        return out_bbox.to(dtype=dtype, device=bbox.device)
......
@@ -342,7 +342,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
        in_dtype = bbox.dtype
        if not torch.is_floating_point(bbox):
            bbox = bbox.float()
-        bbox_xyxy = F.convert_format_bounding_boxes(
+        bbox_xyxy = F.convert_bounding_box_format(
            bbox.as_subclass(torch.Tensor),
            old_format=format,
            new_format=datapoints.BoundingBoxFormat.XYXY,
@@ -366,7 +366,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
            ],
            dtype=bbox_xyxy.dtype,
        )
-        out_bbox = F.convert_format_bounding_boxes(
+        out_bbox = F.convert_bounding_box_format(
            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
        )
        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
......
@@ -374,8 +374,8 @@ DISPATCHER_INFOS = [
        ],
    ),
    DispatcherInfo(
-        F.convert_format_bounding_boxes,
-        kernels={datapoints.BoundingBoxes: F.convert_format_bounding_boxes},
+        F.convert_bounding_box_format,
+        kernels={datapoints.BoundingBoxes: F.convert_bounding_box_format},
        test_marks=[
            skip_dispatch_datapoint,
        ],
......
@@ -190,7 +190,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
        in_dtype = bbox.dtype
        if not torch.is_floating_point(bbox):
            bbox = bbox.float()
-        bbox_xyxy = F.convert_format_bounding_boxes(
+        bbox_xyxy = F.convert_bounding_box_format(
            bbox.as_subclass(torch.Tensor),
            old_format=format_,
            new_format=datapoints.BoundingBoxFormat.XYXY,
@@ -214,7 +214,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
            ],
            dtype=bbox_xyxy.dtype,
        )
-        out_bbox = F.convert_format_bounding_boxes(
+        out_bbox = F.convert_bounding_box_format(
            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
        )
        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
@@ -227,30 +227,30 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
    ).reshape(bounding_boxes.shape)

-def sample_inputs_convert_format_bounding_boxes():
+def sample_inputs_convert_bounding_box_format():
    formats = list(datapoints.BoundingBoxFormat)
    for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
        yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)

-def reference_convert_format_bounding_boxes(bounding_boxes, old_format, new_format):
+def reference_convert_bounding_box_format(bounding_boxes, old_format, new_format):
    return torchvision.ops.box_convert(
        bounding_boxes, in_fmt=old_format.name.lower(), out_fmt=new_format.name.lower()
    ).to(bounding_boxes.dtype)

-def reference_inputs_convert_format_bounding_boxes():
-    for args_kwargs in sample_inputs_convert_format_bounding_boxes():
+def reference_inputs_convert_bounding_box_format():
+    for args_kwargs in sample_inputs_convert_bounding_box_format():
        if len(args_kwargs.args[0].shape) == 2:
            yield args_kwargs

KERNEL_INFOS.append(
    KernelInfo(
-        F.convert_format_bounding_boxes,
-        sample_inputs_fn=sample_inputs_convert_format_bounding_boxes,
-        reference_fn=reference_convert_format_bounding_boxes,
-        reference_inputs_fn=reference_inputs_convert_format_bounding_boxes,
+        F.convert_bounding_box_format,
+        sample_inputs_fn=sample_inputs_convert_bounding_box_format,
+        reference_fn=reference_convert_bounding_box_format,
+        reference_inputs_fn=reference_inputs_convert_bounding_box_format,
        logs_usage=True,
        closeness_kwargs={
            (("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
......
@@ -368,7 +368,7 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
        target["image_id"] = image_id

        if "boxes" in target_keys:
-            target["boxes"] = F.convert_format_bounding_boxes(
+            target["boxes"] = F.convert_bounding_box_format(
                datapoints.BoundingBoxes(
                    batched_target["bbox"],
                    format=datapoints.BoundingBoxFormat.XYWH,
@@ -489,7 +489,7 @@ def celeba_wrapper_factory(dataset, target_keys):
            target,
            target_types=dataset.target_type,
            type_wrappers={
-                "bbox": lambda item: F.convert_format_bounding_boxes(
+                "bbox": lambda item: F.convert_bounding_box_format(
                    datapoints.BoundingBoxes(
                        item,
                        format=datapoints.BoundingBoxFormat.XYWH,
@@ -636,7 +636,7 @@ def widerface_wrapper(dataset, target_keys):
        target = {key: target[key] for key in target_keys}

        if "bbox" in target_keys:
-            target["bbox"] = F.convert_format_bounding_boxes(
+            target["bbox"] = F.convert_bounding_box_format(
                datapoints.BoundingBoxes(
                    target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
                ),
......
@@ -22,6 +22,13 @@ def set_return_type(return_type: str):
    ``torchvision`` transforms or functionals, which will always return as
    output the same type that was passed as input.

+    .. warning::
+
+        We recommend using :class:`~torchvision.transforms.v2.ToPureTensor` at
+        the end of your transform pipelines if you use
+        ``set_return_type("datapoint")``. This will avoid the
+        ``__torch_function__`` overhead in the models' ``forward()``.
+
    Can be used as a global flag for the entire program:

    .. code:: python
......
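
A sketch of what the new warning recommends (editor's illustration, not part of the commit; the pipeline below is hypothetical)::

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2

    datapoints.set_return_type("datapoint")

    transforms = v2.Compose([
        v2.RandomHorizontalFlip(p=0.5),
        v2.ToDtype(torch.float32, scale=True),
        # End with ToPureTensor so the model's forward() sees plain tensors
        # and skips the __torch_function__ dispatch overhead.
        v2.ToPureTensor(),
    ])

    img = datapoints.Image(torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8))
    out = transforms(img)
    assert type(out) is torch.Tensor
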
@@ -80,7 +80,7 @@ class SimpleCopyPaste(Transform):
        # There is a similar +1 in other reference implementations:
        # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
        xyxy_boxes[:, 2:] += 1
-        boxes = F.convert_format_bounding_boxes(
+        boxes = F.convert_bounding_box_format(
            xyxy_boxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
        )
        out_target["boxes"] = torch.cat([boxes, paste_boxes])

@@ -89,7 +89,7 @@ class SimpleCopyPaste(Transform):
        out_target["labels"] = torch.cat([labels, paste_labels])

        # Check for degenerated boxes and remove them
-        boxes = F.convert_format_bounding_boxes(
+        boxes = F.convert_bounding_box_format(
            out_target["boxes"], old_format=bbox_format, new_format=datapoints.BoundingBoxFormat.XYXY
        )
        degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
......
@@ -76,7 +76,7 @@ class FixedSizeCrop(Transform):
                width=new_width,
            )
            bounding_boxes = F.clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size)
-            height_and_width = F.convert_format_bounding_boxes(
+            height_and_width = F.convert_bounding_box_format(
                bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYWH
            )[..., 2:]
            is_valid = torch.all(height_and_width > 0, dim=-1)
......
@@ -10,13 +10,15 @@ from torchvision.transforms.v2 import Transform
class ToTensor(Transform):
-    """[BETA] Convert a PIL Image or ndarray to tensor and scale the values accordingly.
+    """[BETA] [DEPRECATED] Use ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`` instead.
+
+    Convert a PIL Image or ndarray to tensor and scale the values accordingly.

    .. v2betastatus:: ToTensor transform

    .. warning::
        :class:`v2.ToTensor` is deprecated and will be removed in a future release.
-        Please use instead ``v2.Compose([transforms.ToImageTensor(), v2.ToDtype(torch.float32, scale=True)])``.
+        Please use instead ``v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])``.

    This transform does not support torchscript.

@@ -40,7 +42,7 @@ class ToTensor(Transform):
    def __init__(self) -> None:
        warnings.warn(
            "The transform `ToTensor()` is deprecated and will be removed in a future release. "
-            "Instead, please use `v2.Compose([transforms.ToImageTensor(), v2.ToDtype(torch.float32, scale=True)])`."
+            "Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`."
        )
        super().__init__()
......
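
A minimal sketch of the replacement named in the new deprecation message (editor's illustration, not part of the commit; the input image is hypothetical)::

    import torch
    from PIL import Image
    from torchvision.transforms import v2

    # Equivalent of the deprecated v2.ToTensor(): wrap as an Image datapoint,
    # then convert to float32 with values scaled into [0.0, 1.0].
    transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

    pil_img = Image.new("RGB", (4, 4))
    out = transform(pil_img)
    assert out.dtype == torch.float32 and float(out.max()) <= 1.0
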
@@ -1186,7 +1186,7 @@ class RandomIoUCrop(Transform):
                    continue

                # check for any valid boxes with centers within the crop area
-                xyxy_bboxes = F.convert_format_bounding_boxes(
+                xyxy_bboxes = F.convert_bounding_box_format(
                    bboxes.as_subclass(torch.Tensor),
                    bboxes.format,
                    datapoints.BoundingBoxFormat.XYXY,
......
@@ -24,7 +24,7 @@ class ConvertBoundingBoxFormat(Transform):
        self.format = format

    def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
-        return F.convert_format_bounding_boxes(inpt, new_format=self.format)  # type: ignore[return-value]
+        return F.convert_bounding_box_format(inpt, new_format=self.format)  # type: ignore[return-value]

class ClampBoundingBoxes(Transform):
......
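
For reference, a small sketch of the renamed functional next to the ``ConvertBoundingBoxFormat`` transform it backs (editor's illustration, not part of the commit; box values are made up)::

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2
    from torchvision.transforms.v2 import functional as F

    bboxes = datapoints.BoundingBoxes(
        [[10, 20, 30, 40]],
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(100, 100),
    )

    out_transform = v2.ConvertBoundingBoxFormat("CXCYWH")(bboxes)
    out_functional = F.convert_bounding_box_format(
        bboxes, new_format=datapoints.BoundingBoxFormat.CXCYWH
    )
    torch.testing.assert_close(out_transform, out_functional)  # both give [[20, 30, 20, 20]]
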
@@ -293,7 +293,9 @@ class ToDtype(Transform):
class ConvertImageDtype(Transform):
-    """[BETA] Convert input image to the given ``dtype`` and scale the values accordingly.
+    """[BETA] [DEPRECATED] Use ``v2.ToDtype(dtype, scale=True)`` instead.
+
+    Convert input image to the given ``dtype`` and scale the values accordingly.

    .. v2betastatus:: ConvertImageDtype transform

@@ -388,7 +390,7 @@ class SanitizeBoundingBoxes(Transform):
        boxes = cast(
            datapoints.BoundingBoxes,
-            F.convert_format_bounding_boxes(
+            F.convert_bounding_box_format(
                boxes,
                new_format=datapoints.BoundingBoxFormat.XYXY,
            ),
......
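
Similarly, the replacement named in the ``ConvertImageDtype`` deprecation, sketched on a uint8 image (editor's illustration, not part of the commit)::

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2

    img = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))

    # The deprecated ConvertImageDtype(torch.float32) maps onto ToDtype with scale=True.
    out = v2.ToDtype(torch.float32, scale=True)(img)
    assert out.dtype == torch.float32
    assert float(out.max()) <= 1.0  # uint8 values in [0, 255] are rescaled into [0.0, 1.0]
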
@@ -4,7 +4,7 @@ from ._utils import is_pure_tensor, register_kernel  # usort: skip
from ._meta import (
    clamp_bounding_boxes,
-    convert_format_bounding_boxes,
+    convert_bounding_box_format,
    get_dimensions_image,
    _get_dimensions_image_pil,
    get_dimensions_video,
......
@@ -17,6 +17,7 @@ def erase(
    v: torch.Tensor,
    inplace: bool = False,
) -> torch.Tensor:
+    """[BETA] See :class:`~torchvision.transforms.v2.RandomErase` for details."""
    if torch.jit.is_scripting():
        return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
......
@@ -15,6 +15,7 @@ from ._utils import _get_kernel, _register_kernel_internal
def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
+    """[BETA] See :class:`~torchvision.transforms.v2.Grayscale` for details."""
    if torch.jit.is_scripting():
        return rgb_to_grayscale_image(inpt, num_output_channels=num_output_channels)

@@ -69,6 +70,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
+    """Adjust brightness."""
    if torch.jit.is_scripting():
        return adjust_brightness_image(inpt, brightness_factor=brightness_factor)

@@ -106,6 +108,7 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
+    """Adjust saturation."""
    if torch.jit.is_scripting():
        return adjust_saturation_image(inpt, saturation_factor=saturation_factor)

@@ -144,6 +147,7 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
+    """[BETA] See :class:`~torchvision.transforms.RandomAutocontrast`"""
    if torch.jit.is_scripting():
        return adjust_contrast_image(inpt, contrast_factor=contrast_factor)

@@ -182,6 +186,7 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
+    """[BETA] See :class:`~torchvision.transforms.RandomAdjustSharpness`"""
    if torch.jit.is_scripting():
        return adjust_sharpness_image(inpt, sharpness_factor=sharpness_factor)

@@ -254,6 +259,7 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
+    """Adjust hue"""
    if torch.jit.is_scripting():
        return adjust_hue_image(inpt, hue_factor=hue_factor)

@@ -371,6 +377,7 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
+    """Adjust gamma."""
    if torch.jit.is_scripting():
        return adjust_gamma_image(inpt, gamma=gamma, gain=gain)

@@ -410,6 +417,7 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
+    """[BETA] See :class:`~torchvision.transforms.v2.RandomPosterize` for details."""
    if torch.jit.is_scripting():
        return posterize_image(inpt, bits=bits)

@@ -443,6 +451,7 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
+    """[BETA] See :class:`~torchvision.transforms.v2.RandomSolarize` for details."""
    if torch.jit.is_scripting():
        return solarize_image(inpt, threshold=threshold)

@@ -470,6 +479,7 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
+    """[BETA] See :class:`~torchvision.transforms.v2.RandomAutocontrast` for details."""
    if torch.jit.is_scripting():
        return autocontrast_image(inpt)

@@ -519,6 +529,7 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
def equalize(inpt: torch.Tensor) -> torch.Tensor:
+    """[BETA] See :class:`~torchvision.transforms.v2.RandomEqualize` for details."""
    if torch.jit.is_scripting():
        return equalize_image(inpt)

@@ -608,6 +619,7 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
def invert(inpt: torch.Tensor) -> torch.Tensor:
+    """[BETA] See :func:`~torchvision.transforms.v2.RandomInvert`."""
    if torch.jit.is_scripting():
        return invert_image(inpt)
......
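
The new one-line docstrings point each functional at its transform counterpart; a small sketch of how the two relate, using ``posterize`` (editor's illustration, not part of the commit)::

    import torch
    from torchvision.transforms import v2
    from torchvision.transforms.v2 import functional as F

    img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

    out_functional = F.posterize(img, bits=2)                # always applies the op
    out_transform = v2.RandomPosterize(bits=2, p=1.0)(img)   # applies it with probability p
    torch.testing.assert_close(out_functional, out_transform)
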