Unverified Commit 26ed129d authored by Nicolas Hug, committed by GitHub

Make v2.utils private. (#7863)

parent 9c4f7389
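In practice, every import of torchvision.transforms.v2.utils in the hunks below becomes an import of torchvision.transforms.v2._utils. A minimal before/after sketch of the migration (illustrative, not itself part of the diff):

    # Before this commit: the helpers lived in a public module.
    from torchvision.transforms.v2.utils import is_pure_tensor, query_chw

    # After this commit: the same helpers live in the private module.
    from torchvision.transforms.v2._utils import is_pure_tensor, query_chw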
@@ -11,7 +11,7 @@ class PadIfSmaller(v2.Transform):
         self.fill = v2._utils._setup_fill_arg(fill)

     def _get_params(self, sample):
-        _, height, width = v2.utils.query_chw(sample)
+        _, height, width = v2._utils.query_chw(sample)
         padding = [0, 0, max(self.size - width, 0), max(self.size - height, 0)]
         needs_padding = any(padding)
         return dict(padding=padding, needs_padding=needs_padding)
...
@@ -25,7 +25,7 @@ from torchvision.prototype import datasets
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.transforms.v2.utils import is_pure_tensor
+from torchvision.transforms.v2._utils import is_pure_tensor


 def assert_samples_equal(*args, msg=None, **kwargs):
...
@@ -10,8 +10,8 @@ from prototype_common_utils import make_label
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
+from torchvision.transforms.v2._utils import check_type, is_pure_tensor
 from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
-from torchvision.transforms.v2.utils import check_type, is_pure_tensor
 from transforms_v2_legacy_utils import (
     DEFAULT_EXTRA_DIMS,
     make_bounding_boxes,
...
@@ -16,7 +16,7 @@ from torchvision import datapoints
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.utils import check_type, is_pure_tensor, query_chw
+from torchvision.transforms.v2._utils import check_type, is_pure_tensor, query_chw
 from transforms_v2_legacy_utils import (
     make_bounding_boxes,
     make_detection_mask,
...
@@ -19,9 +19,8 @@ from torchvision._utils import sequence_to_str
 from torchvision.transforms import functional as legacy_F
 from torchvision.transforms.v2 import functional as prototype_F
-from torchvision.transforms.v2._utils import _get_fill
+from torchvision.transforms.v2._utils import _get_fill, query_size
 from torchvision.transforms.v2.functional import to_pil_image
-from torchvision.transforms.v2.utils import query_size
 from transforms_v2_legacy_utils import (
     ArgsKwargs,
     make_bounding_boxes,
...
@@ -13,9 +13,9 @@ from torch.utils._pytree import tree_map
 from torchvision import datapoints
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F
+from torchvision.transforms.v2._utils import is_pure_tensor
 from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
 from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_bounding_box_format
-from torchvision.transforms.v2.utils import is_pure_tensor
 from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
 from transforms_v2_kernel_infos import KERNEL_INFOS
 from transforms_v2_legacy_utils import (
...
@@ -3,12 +3,12 @@ import pytest
 import torch
-import torchvision.transforms.v2.utils
+import torchvision.transforms.v2._utils
 from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image
 from torchvision import datapoints
+from torchvision.transforms.v2._utils import has_all, has_any
 from torchvision.transforms.v2.functional import to_pil_image
-from torchvision.transforms.v2.utils import has_all, has_any

 IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
@@ -37,15 +37,15 @@ MASK = make_detection_mask(DEFAULT_SIZE)
         ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
-        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor), True),
+        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
         (
             (torch.Tensor(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
             True,
         ),
         (
             (to_pil_image(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
             True,
         ),
     ],
...
@@ -7,9 +7,9 @@ from torchvision import datapoints
 from torchvision.ops import masks_to_boxes
 from torchvision.prototype import datapoints as proto_datapoints
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
+from torchvision.transforms.v2._utils import is_pure_tensor
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.utils import is_pure_tensor


 class SimpleCopyPaste(Transform):
...
@@ -6,8 +6,16 @@ import torch
 from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_pure_tensor, query_size
+from torchvision.transforms.v2._utils import (
+    _FillType,
+    _get_fill,
+    _setup_fill_arg,
+    _setup_size,
+    get_bounding_boxes,
+    has_any,
+    is_pure_tensor,
+    query_size,
+)


 class FixedSizeCrop(Transform):
...
@@ -8,7 +8,7 @@ import torch
 from torchvision import datapoints
 from torchvision.transforms.v2 import Transform
-from torchvision.transforms.v2.utils import is_pure_tensor
+from torchvision.transforms.v2._utils import is_pure_tensor

 T = TypeVar("T")
...
 from torchvision.transforms import AutoAugmentPolicy, InterpolationMode  # usort: skip
-from . import functional, utils  # usort: skip
+from . import functional  # usort: skip
 from ._transform import Transform  # usort: skip
...
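Because utils is dropped from the package __init__ above and the module file itself is deleted at the end of this commit, the old public path stops resolving entirely. A hypothetical compatibility shim for downstream code that must work on both sides of this change:

    try:
        from torchvision.transforms.v2 import utils as _v2_utils  # before this commit
    except ImportError:
        from torchvision.transforms.v2 import _utils as _v2_utils  # after this commit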
@@ -11,8 +11,7 @@ from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F

 from ._transform import _RandomApplyTransform, Transform
-from ._utils import _parse_labels_getter
-from .utils import has_any, is_pure_tensor, query_chw, query_size
+from ._utils import _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size


 class RandomErasing(_RandomApplyTransform):
...
@@ -12,8 +12,7 @@ from torchvision.transforms.v2.functional._geometry import _check_interpolation
 from torchvision.transforms.v2.functional._meta import get_size
 from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

-from ._utils import _get_fill, _setup_fill_arg
-from .utils import check_type, is_pure_tensor
+from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor

 ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
...
@@ -6,7 +6,7 @@ from torchvision import transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform

 from ._transform import _RandomApplyTransform
-from .utils import query_chw
+from ._utils import query_chw


 class Grayscale(Transform):
...
@@ -23,8 +23,12 @@ from ._utils import (
     _setup_fill_arg,
     _setup_float_or_seq,
     _setup_size,
+    get_bounding_boxes,
+    has_all,
+    has_any,
+    is_pure_tensor,
+    query_size,
 )
-from .utils import get_bounding_boxes, has_all, has_any, is_pure_tensor, query_size


 class RandomHorizontalFlip(_RandomApplyTransform):
...
@@ -9,8 +9,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform

-from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
-from .utils import get_bounding_boxes, has_any, is_pure_tensor
+from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor

 # TODO: do we want/need to expose this?
...
@@ -8,7 +8,7 @@ import torch
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints
-from torchvision.transforms.v2.utils import check_type, has_any, is_pure_tensor
+from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
 from torchvision.utils import _log_api_usage_once

 from .functional._utils import _get_kernel
...
@@ -7,7 +7,7 @@ import torch
 from torchvision import datapoints
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2.utils import is_pure_tensor
+from torchvision.transforms.v2._utils import is_pure_tensor


 class PILToTensor(Transform):
...
+from __future__ import annotations
+
 import collections.abc
 import numbers
 from contextlib import suppress
-from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union

+import PIL.Image
 import torch

+from torchvision import datapoints
+from torchvision._utils import sequence_to_str
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
+from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
 from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
@@ -138,3 +147,73 @@ def _parse_labels_getter(
         return lambda _: None
     else:
         raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")
+
+
+def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
+    # This assumes there is only one bbox per sample as per the general convention
+    try:
+        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
+    except StopIteration:
+        raise ValueError("No bounding boxes were found in the sample")
+
+
+def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
+    chws = {
+        tuple(get_dimensions(inpt))
+        for inpt in flat_inputs
+        if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
+    }
+    if not chws:
+        raise TypeError("No image or video was found in the sample")
+    elif len(chws) > 1:
+        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
+    c, h, w = chws.pop()
+    return c, h, w
+
+
+def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
+    sizes = {
+        tuple(get_size(inpt))
+        for inpt in flat_inputs
+        if check_type(
+            inpt,
+            (
+                is_pure_tensor,
+                datapoints.Image,
+                PIL.Image.Image,
+                datapoints.Video,
+                datapoints.Mask,
+                datapoints.BoundingBoxes,
+            ),
+        )
+    }
+    if not sizes:
+        raise TypeError("No image, video, mask or bounding box was found in the sample")
+    elif len(sizes) > 1:
+        raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
+    h, w = sizes.pop()
+    return h, w
+
+
+def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
+    for type_or_check in types_or_checks:
+        if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
+            return True
+    return False
+
+
+def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
+    for inpt in flat_inputs:
+        if check_type(inpt, types_or_checks):
+            return True
+    return False
+
+
+def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
+    for type_or_check in types_or_checks:
+        for inpt in flat_inputs:
+            if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
+                break
+        else:
+            return False
+    return True
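For orientation, a small usage sketch of the helpers that just moved into _utils (illustrative only, assuming a torchvision build containing this commit): check_type accepts a mix of plain types and predicate callables, and has_any applies it across a flattened sample.

    import torch
    from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor

    sample = [torch.rand(3, 32, 32), "some label"]
    check_type(sample[0], (is_pure_tensor,))  # True: the predicate callable matches a plain tensor
    has_any(sample, str)                      # True: at least one input in the sample is a str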
...
-from __future__ import annotations
-
-from typing import Any, Callable, List, Tuple, Type, Union
-
-import PIL.Image
-
-from torchvision import datapoints
-from torchvision._utils import sequence_to_str
-from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
-
-
-def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
-    # This assumes there is only one bbox per sample as per the general convention
-    try:
-        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
-    except StopIteration:
-        raise ValueError("No bounding boxes were found in the sample")
-
-
-def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
-    chws = {
-        tuple(get_dimensions(inpt))
-        for inpt in flat_inputs
-        if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
-    }
-    if not chws:
-        raise TypeError("No image or video was found in the sample")
-    elif len(chws) > 1:
-        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
-    c, h, w = chws.pop()
-    return c, h, w
-
-
-def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
-    sizes = {
-        tuple(get_size(inpt))
-        for inpt in flat_inputs
-        if check_type(
-            inpt,
-            (
-                is_pure_tensor,
-                datapoints.Image,
-                PIL.Image.Image,
-                datapoints.Video,
-                datapoints.Mask,
-                datapoints.BoundingBoxes,
-            ),
-        )
-    }
-    if not sizes:
-        raise TypeError("No image, video, mask or bounding box was found in the sample")
-    elif len(sizes) > 1:
-        raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
-    h, w = sizes.pop()
-    return h, w
-
-
-def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
-    for type_or_check in types_or_checks:
-        if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
-            return True
-    return False
-
-
-def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
-    for inpt in flat_inputs:
-        if check_type(inpt, types_or_checks):
-            return True
-    return False
-
-
-def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
-    for type_or_check in types_or_checks:
-        for inpt in flat_inputs:
-            if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
-                break
-        else:
-            return False
-    return True
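Similarly, query_chw and query_size, which now exist only in _utils, pick the unique CxHxW or HxW out of a heterogeneous sample and raise if none or several are found. A hedged sketch under the same assumptions as above:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2._utils import query_chw, query_size

    sample = [datapoints.Image(torch.rand(3, 16, 24)), "metadata"]
    query_chw(sample)   # (3, 16, 24): the single CxHxW among image/video inputs
    query_size(sample)  # (16, 24): the single HxW among spatial inputs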