"...text-generation-inference.git" did not exist on "e605c2a43e693844cb2c5ba879f41392faf64793"
Unverified commit 72dcc170 authored by Nicolas Hug, committed by GitHub

Put all dtype conversion stuff into the _misc namespaces (#7770)

parent 35913710
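This commit moves the dtype-conversion surface out of the `_meta` namespaces: the `ConvertImageDtype` transform goes from `transforms.v2._meta` to `transforms.v2._misc`, and the functional kernels (`to_dtype`, `to_dtype_image_tensor`, `to_dtype_video`, `convert_image_dtype`, plus the private `_num_value_bits` helper) go from `functional._meta` to `functional._misc`. The public import paths should be unaffected; a minimal sketch of that surface, assuming a torchvision build with the v2 beta API (the sample tensor is made up for illustration):

import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F

# A hypothetical uint8 image tensor (CHW).
img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

# Both public entry points keep resolving as before; only the private
# module that defines them changes in this commit.
out_transform = v2.ConvertImageDtype(torch.float32)(img)
out_functional = F.to_dtype(img, dtype=torch.float32, scale=True)
assert torch.equal(out_transform, out_functional)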
@@ -39,8 +39,17 @@ from ._geometry import (
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertImageDtype
-from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype
+from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat
+from ._misc import (
+    ConvertImageDtype,
+    GaussianBlur,
+    Identity,
+    Lambda,
+    LinearTransformation,
+    Normalize,
+    SanitizeBoundingBox,
+    ToDtype,
+)
 from ._temporal import UniformTemporalSubsample
 from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
...
 from typing import Any, Dict, Union

-import torch
-from torchvision import datapoints, transforms as _transforms
+from torchvision import datapoints
 from torchvision.transforms.v2 import functional as F, Transform

-from .utils import is_simple_tensor


 class ConvertBoundingBoxFormat(Transform):
     """[BETA] Convert bounding box coordinates to the given ``format``, eg from "CXCYWH" to "XYXY".
@@ -31,43 +27,6 @@ class ConvertBoundingBoxFormat(Transform):
         return F.convert_format_bounding_box(inpt, new_format=self.format)  # type: ignore[return-value]
-
-
-class ConvertImageDtype(Transform):
-    """[BETA] Convert input image to the given ``dtype`` and scale the values accordingly.
-
-    .. v2betastatus:: ConvertImageDtype transform
-
-    .. warning::
-        Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`.
-
-    This function does not support PIL Image.
-
-    Args:
-        dtype (torch.dtype): Desired data type of the output
-
-    .. note::
-        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
-        If converted back and forth, this mismatch has no effect.
-
-    Raises:
-        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
-            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
-            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
-            of the integer ``dtype``.
-    """
-
-    _v1_transform_cls = _transforms.ConvertImageDtype
-
-    _transformed_types = (is_simple_tensor, datapoints.Image)
-
-    def __init__(self, dtype: torch.dtype = torch.float32) -> None:
-        super().__init__()
-        self.dtype = dtype
-
-    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.to_dtype(inpt, dtype=self.dtype, scale=True)


 class ClampBoundingBox(Transform):
     """[BETA] Clamp bounding boxes to their corresponding image dimensions.
...
@@ -295,6 +295,43 @@ class ToDtype(Transform):
         return F.to_dtype(inpt, dtype=dtype, scale=self.scale)
+
+
+class ConvertImageDtype(Transform):
+    """[BETA] Convert input image to the given ``dtype`` and scale the values accordingly.
+
+    .. v2betastatus:: ConvertImageDtype transform
+
+    .. warning::
+        Consider using ``ToDtype(dtype, scale=True)`` instead. See :class:`~torchvision.transforms.v2.ToDtype`.
+
+    This function does not support PIL Image.
+
+    Args:
+        dtype (torch.dtype): Desired data type of the output
+
+    .. note::
+        When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
+        If converted back and forth, this mismatch has no effect.
+
+    Raises:
+        RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
+            well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
+            overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
+            of the integer ``dtype``.
+    """
+
+    _v1_transform_cls = _transforms.ConvertImageDtype
+
+    _transformed_types = (is_simple_tensor, datapoints.Image)
+
+    def __init__(self, dtype: torch.dtype = torch.float32) -> None:
+        super().__init__()
+        self.dtype = dtype
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.to_dtype(inpt, dtype=self.dtype, scale=True)


 class SanitizeBoundingBox(Transform):
     """[BETA] Remove degenerate/invalid bounding boxes and their corresponding labels and masks.
...
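The docstring above steers users toward `ToDtype(dtype, scale=True)` as the preferred spelling. A small sketch of the equivalence it suggests, assuming the v2 beta API as shown in this diff (the input tensor is made up):

import torch
from torchvision.transforms import v2

img = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)

# ConvertImageDtype is kept for backward compatibility; ToDtype with
# scale=True dispatches to the same F.to_dtype kernel per this diff.
legacy = v2.ConvertImageDtype(torch.float32)(img)
preferred = v2.ToDtype(torch.float32, scale=True)(img)
assert torch.equal(legacy, preferred)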
@@ -5,10 +5,6 @@ from ._utils import is_simple_tensor  # usort: skip
 from ._meta import (
     clamp_bounding_box,
     convert_format_bounding_box,
-    convert_image_dtype,
-    to_dtype,
-    to_dtype_image_tensor,
-    to_dtype_video,
     get_dimensions_image_tensor,
     get_dimensions_image_pil,
     get_dimensions,
@@ -158,6 +154,7 @@ from ._geometry import (
     vflip,
 )
 from ._misc import (
+    convert_image_dtype,
     gaussian_blur,
     gaussian_blur_image_pil,
     gaussian_blur_image_tensor,
@@ -165,6 +162,9 @@ from ._misc import (
     normalize,
     normalize_image_tensor,
     normalize_video,
+    to_dtype,
+    to_dtype_image_tensor,
+    to_dtype_video,
 )
 from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
 from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image
...
@@ -9,7 +9,7 @@ from torchvision.transforms._functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once

-from ._meta import _num_value_bits, to_dtype_image_tensor
+from ._misc import _num_value_bits, to_dtype_image_tensor
 from ._utils import is_simple_tensor
...
@@ -5,7 +5,6 @@ import torch
 from torchvision import datapoints
 from torchvision.datapoints import BoundingBoxFormat
 from torchvision.transforms import _functional_pil as _FP
-from torchvision.transforms._functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once

@@ -279,97 +278,3 @@ def clamp_bounding_box(
         raise TypeError(
             f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
         )
-
-
-def _num_value_bits(dtype: torch.dtype) -> int:
-    if dtype == torch.uint8:
-        return 8
-    elif dtype == torch.int8:
-        return 7
-    elif dtype == torch.int16:
-        return 15
-    elif dtype == torch.int32:
-        return 31
-    elif dtype == torch.int64:
-        return 63
-    else:
-        raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
-
-
-def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
-    if image.dtype == dtype:
-        return image
-    elif not scale:
-        return image.to(dtype)
-
-    float_input = image.is_floating_point()
-    if torch.jit.is_scripting():
-        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
-        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
-    else:
-        float_output = dtype.is_floating_point
-
-    if float_input:
-        # float to float
-        if float_output:
-            return image.to(dtype)
-
-        # float to int
-        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
-            image.dtype == torch.float64 and dtype == torch.int64
-        ):
-            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")
-
-        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
-        # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
-        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
-        # for a detailed analysis.
-        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
-        # Instead, we can also multiply by the maximum value plus something close to `1`. See
-        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
-        eps = 1e-3
-        max_value = float(_max_value(dtype))
-        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
-        # discrete set `{0, 1}`.
-        return image.mul(max_value + 1.0 - eps).to(dtype)
-    else:
-        # int to float
-        if float_output:
-            return image.to(dtype).mul_(1.0 / _max_value(image.dtype))
-
-        # int to int
-        num_value_bits_input = _num_value_bits(image.dtype)
-        num_value_bits_output = _num_value_bits(dtype)
-
-        if num_value_bits_input > num_value_bits_output:
-            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
-        else:
-            return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
-
-
-# We encourage users to use to_dtype() instead but we keep this for BC
-def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
-    return to_dtype_image_tensor(image, dtype=dtype, scale=True)
-
-
-def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
-    return to_dtype_image_tensor(video, dtype, scale=scale)
-
-
-def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(to_dtype)
-
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return to_dtype_image_tensor(inpt, dtype, scale=scale)
-    elif isinstance(inpt, datapoints.Image):
-        output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
-        return datapoints.Image.wrap_like(inpt, output)
-    elif isinstance(inpt, datapoints.Video):
-        output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
-        return datapoints.Video.wrap_like(inpt, output)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.to(dtype)
-    else:
-        raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.")
...
@@ -6,6 +6,7 @@ import torch
 from torch.nn.functional import conv2d, pad as torch_pad

 from torchvision import datapoints
+from torchvision.transforms._functional_tensor import _max_value
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once
@@ -182,3 +183,97 @@ def gaussian_blur(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
             f"but got {type(inpt)} instead."
         )
+
+
+def _num_value_bits(dtype: torch.dtype) -> int:
+    if dtype == torch.uint8:
+        return 8
+    elif dtype == torch.int8:
+        return 7
+    elif dtype == torch.int16:
+        return 15
+    elif dtype == torch.int32:
+        return 31
+    elif dtype == torch.int64:
+        return 63
+    else:
+        raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
+
+
+def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
+    if image.dtype == dtype:
+        return image
+    elif not scale:
+        return image.to(dtype)
+
+    float_input = image.is_floating_point()
+    if torch.jit.is_scripting():
+        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
+        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
+    else:
+        float_output = dtype.is_floating_point
+
+    if float_input:
+        # float to float
+        if float_output:
+            return image.to(dtype)
+
+        # float to int
+        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
+            image.dtype == torch.float64 and dtype == torch.int64
+        ):
+            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")
+
+        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
+        # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
+        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
+        # for a detailed analysis.
+        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
+        # Instead, we can also multiply by the maximum value plus something close to `1`. See
+        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
+        eps = 1e-3
+        max_value = float(_max_value(dtype))
+        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
+        # discrete set `{0, 1}`.
+        return image.mul(max_value + 1.0 - eps).to(dtype)
+    else:
+        # int to float
+        if float_output:
+            return image.to(dtype).mul_(1.0 / _max_value(image.dtype))
+
+        # int to int
+        num_value_bits_input = _num_value_bits(image.dtype)
+        num_value_bits_output = _num_value_bits(dtype)
+
+        if num_value_bits_input > num_value_bits_output:
+            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
+        else:
+            return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
+
+
+# We encourage users to use to_dtype() instead but we keep this for BC
+def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    return to_dtype_image_tensor(image, dtype=dtype, scale=True)
+
+
+def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
+    return to_dtype_image_tensor(video, dtype, scale=scale)
+
+
+def to_dtype(inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(to_dtype)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return to_dtype_image_tensor(inpt, dtype, scale=scale)
+    elif isinstance(inpt, datapoints.Image):
+        output = to_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
+        return datapoints.Image.wrap_like(inpt, output)
+    elif isinstance(inpt, datapoints.Video):
+        output = to_dtype_video(inpt.as_subclass(torch.Tensor), dtype, scale=scale)
+        return datapoints.Video.wrap_like(inpt, output)
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
+        return inpt.to(dtype)
+    else:
+        raise TypeError(f"Input can either be a plain tensor or a datapoint, but got {type(inpt)} instead.")