Unverified Commit 72dcc170 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Put all dtype conversion stuff into the _misc namespaces (#7770)

parent 35913710
......@@ -39,8 +39,17 @@ from ._geometry import (
ScaleJitter,
TenCrop,
)
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat
from ._misc import (
ConvertImageDtype,
GaussianBlur,
Identity,
Lambda,
LinearTransformation,
Normalize,
SanitizeBoundingBox,
ToDtype,
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
......
from typing import Any, Dict, Union
import torch
from torchvision import datapoints, transforms as _transforms
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F, Transform
from .utils import is_simple_tensor
class ConvertBoundingBoxFormat(Transform):
"""[BETA] Convert bounding box coordinates to the given ``format``, eg from "CXCYWH" to "XYXY".
......@@ -31,43 +27,6 @@ class ConvertBoundingBoxFormat(Transform):
return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value]
class ConvertImageDtype(Transform):
"""[BETA] Convert input image to the given ``dtype`` and scale the values accordingly.
.. v2betastatus:: ConvertImageDtype transform
.. warning::
Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`.
This function does not support PIL Image.
Args:
dtype (torch.dtype): Desired data type of the output
.. note::
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
If converted back and forth, this mismatch has no effect.
Raises:
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""
_v1_transform_cls = _transforms.ConvertImageDtype
_transformed_types = (is_simple_tensor, datapoints.Image)
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.to_dtype(inpt, dtype=self.dtype, scale=True)
class ClampBoundingBox(Transform):
"""[BETA] Clamp bounding boxes to their corresponding image dimensions.
......
......@@ -295,6 +295,43 @@ class ToDtype(Transform):
return F.to_dtype(inpt, dtype=dtype, scale=self.scale)
class ConvertImageDtype(Transform):
"""[BETA] Convert input image to the given ``dtype`` and scale the values accordingly.
.. v2betastatus:: ConvertImageDtype transform
.. warning::
Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`.
This function does not support PIL Image.
Args:
dtype (torch.dtype): Desired data type of the output
.. note::
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
If converted back and forth, this mismatch has no effect.
Raises:
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
of the integer ``dtype``.
"""
_v1_transform_cls = _transforms.ConvertImageDtype
_transformed_types = (is_simple_tensor, datapoints.Image)
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.to_dtype(inpt, dtype=self.dtype, scale=True)
class SanitizeBoundingBox(Transform):
"""[BETA] Remove degenerate/invalid bounding boxes and their corresponding labels and masks.
......
......@@ -5,10 +5,6 @@ from ._utils import is_simple_tensor # usort: skip
from ._meta import (
clamp_bounding_box,
convert_format_bounding_box,
convert_image_dtype,
to_dtype,
to_dtype_image_tensor,
to_dtype_video,
get_dimensions_image_tensor,
get_dimensions_image_pil,
get_dimensions,
......@@ -158,6 +154,7 @@ from ._geometry import (
vflip,
)
from ._misc import (
convert_image_dtype,
gaussian_blur,
gaussian_blur_image_pil,
gaussian_blur_image_tensor,
......@@ -165,6 +162,9 @@ from ._misc import (
normalize,
normalize_image_tensor,
normalize_video,
to_dtype,
to_dtype_image_tensor,
to_dtype_video,
)
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image
......
......@@ -9,7 +9,7 @@ from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once
from ._meta import _num_value_bits, to_dtype_image_tensor
from ._misc import _num_value_bits, to_dtype_image_tensor
from ._utils import is_simple_tensor
......
......@@ -5,7 +5,6 @@ import torch
from torchvision import datapoints
from torchvision.datapoints import BoundingBoxFormat
from torchvision.transforms import _functional_pil as _FP
from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once
......@@ -279,97 +278,3 @@ def clamp_bounding_box(
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
)
def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
elif dtype == torch.int8:
return 7
elif dtype == torch.int16:
return 15
elif dtype == torch.int32:
return 31
elif dtype == torch.int64:
return 63
else:
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if image.dtype == dtype:
return image
elif not scale:
return image.to(dtype)
float_input = image.is_floating_point()
if torch.jit.is_scripting():
# TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
float_output = torch.tensor(0, dtype=dtype).is_floating_point()
else:
float_output = dtype.is_floating_point
if float_input:
# float to float
if float_output:
return image.to(dtype)
# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")
# For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
# to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
# be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# for a detailed analysis.
# To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
# Instead, we can also multiply by the maximum value plus something close to `1`. See
# https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
eps = 1e-3
max_value = float(_max_value(dtype))
# We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
# discrete set `{0, 1}`.
return image.mul(max_value + 1.0 - eps).to(dtype)
else:
# int to float
if float_output:
return image.to(dtype).mul_(1.0 / _max_value(image.dtype))
# int to int
num_value_bits_input = _num_value_bits(image.dtype)
num_value_bits_output = _num_value_bits(dtype)
if num_value_bits_input > num_value_bits_output:
return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
else:
return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
# We encourage users to use to_dtype() instead but we keep this for BC
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
return to_dtype_image_tensor(image, dtype=dtype, scale=True)
def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
return to_dtype_image_tensor(video, dtype, scale=scale)
def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(to_dtype)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return to_dtype_image_tensor(inpt, dtype, scale=scale)
elif isinstance(inpt, datapoints.Image):
output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Video.wrap_like(inpt, output)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.to(dtype)
else:
raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.")
......@@ -6,6 +6,7 @@ import torch
from torch.nn.functional import conv2d, pad as torch_pad
from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once
......@@ -182,3 +183,97 @@ def gaussian_blur(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
elif dtype == torch.int8:
return 7
elif dtype == torch.int16:
return 15
elif dtype == torch.int32:
return 31
elif dtype == torch.int64:
return 63
else:
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if image.dtype == dtype:
return image
elif not scale:
return image.to(dtype)
float_input = image.is_floating_point()
if torch.jit.is_scripting():
# TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
float_output = torch.tensor(0, dtype=dtype).is_floating_point()
else:
float_output = dtype.is_floating_point
if float_input:
# float to float
if float_output:
return image.to(dtype)
# float to int
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
image.dtype == torch.float64 and dtype == torch.int64
):
raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")
# For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
# to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
# be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
# for a detailed analysis.
# To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
# Instead, we can also multiply by the maximum value plus something close to `1`. See
# https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
eps = 1e-3
max_value = float(_max_value(dtype))
# We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
# discrete set `{0, 1}`.
return image.mul(max_value + 1.0 - eps).to(dtype)
else:
# int to float
if float_output:
return image.to(dtype).mul_(1.0 / _max_value(image.dtype))
# int to int
num_value_bits_input = _num_value_bits(image.dtype)
num_value_bits_output = _num_value_bits(dtype)
if num_value_bits_input > num_value_bits_output:
return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
else:
return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
# We encourage users to use to_dtype() instead but we keep this for BC
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
return to_dtype_image_tensor(image, dtype=dtype, scale=True)
def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
return to_dtype_image_tensor(video, dtype, scale=scale)
def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(to_dtype)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return to_dtype_image_tensor(inpt, dtype, scale=scale)
elif isinstance(inpt, datapoints.Image):
output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
return datapoints.Video.wrap_like(inpt, output)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.to(dtype)
else:
raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment