OpenDAS / vision · Commits

Unverified commit d5f4cc38, authored Aug 30, 2023 by Nicolas Hug, committed by GitHub on Aug 30, 2023

    Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd

Changes: 85 files. Showing 20 changed files with 384 additions and 384 deletions (+384, -384).
torchvision/prototype/tv_tensors/_label.py  +2 −2
torchvision/transforms/v2/_augment.py  +11 −11
torchvision/transforms/v2/_auto_augment.py  +8 −8
torchvision/transforms/v2/_geometry.py  +58 −58
torchvision/transforms/v2/_meta.py  +9 −9
torchvision/transforms/v2/_misc.py  +18 −18
torchvision/transforms/v2/_transform.py  +4 −4
torchvision/transforms/v2/_type_conversion.py  +6 −6
torchvision/transforms/v2/_utils.py  +8 −8
torchvision/transforms/v2/functional/_augment.py  +3 −3
torchvision/transforms/v2/functional/_color.py  +27 −27
torchvision/transforms/v2/functional/_geometry.py  +126 −126
torchvision/transforms/v2/functional/_meta.py  +22 −22
torchvision/transforms/v2/functional/_misc.py  +10 −10
torchvision/transforms/v2/functional/_temporal.py  +2 −2
torchvision/transforms/v2/functional/_type_conversion.py  +3 −3
torchvision/transforms/v2/functional/_utils.py  +24 −24
torchvision/tv_tensors/__init__.py  +5 −5
torchvision/tv_tensors/_bounding_box.py  +4 −4
torchvision/tv_tensors/_dataset_wrapper.py  +34 −34
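The commit is a mechanical rename: the `torchvision.datapoints` namespace becomes `torchvision.tv_tensors`, and the `Datapoint` base class becomes `TVTensor`. A minimal before/after sketch of user code, assuming a build that includes this commit:

```python
import torch

# Before this commit:
# from torchvision import datapoints
# img = datapoints.Image(torch.rand(3, 224, 224))

# After this commit:
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 224, 224))
print(isinstance(img, tv_tensors.TVTensor))  # True: TVTensor replaces Datapoint as the base class
```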
torchvision/prototype/datapoints/_label.py → torchvision/prototype/tv_tensors/_label.py

@@ -5,13 +5,13 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union
 import torch
 from torch.utils._pytree import tree_map

-from torchvision.datapoints._datapoint import Datapoint
+from torchvision.tv_tensors._tv_tensor import TVTensor

 L = TypeVar("L", bound="_LabelBase")


-class _LabelBase(Datapoint):
+class _LabelBase(TVTensor):
     categories: Optional[Sequence[str]]

     @classmethod
torchvision/transforms/v2/_augment.py

@@ -7,7 +7,7 @@ import PIL.Image
 import torch
 from torch.nn.functional import one_hot
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms, tv_tensors
 from torchvision.transforms.v2 import functional as F

 from ._transform import _RandomApplyTransform, Transform

@@ -91,10 +91,10 @@ class RandomErasing(_RandomApplyTransform):
         self._log_ratio = torch.log(torch.tensor(self.ratio))

     def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
-        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
-                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
             )
         return super()._call_kernel(functional, inpt, *args, **kwargs)

@@ -158,7 +158,7 @@ class _BaseMixUpCutMix(Transform):
         flat_inputs, spec = tree_flatten(inputs)
         needs_transform_list = self._needs_transform_list(flat_inputs)

-        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask):
+        if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
             raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")

         labels = self._labels_getter(inputs)

@@ -188,7 +188,7 @@ class _BaseMixUpCutMix(Transform):
         return tree_unflatten(flat_outputs, spec)

     def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
-        expected_num_dims = 5 if isinstance(inpt, datapoints.Video) else 4
+        expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
         if inpt.ndim != expected_num_dims:
             raise ValueError(
                 f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."

@@ -242,13 +242,13 @@ class MixUp(_BaseMixUpCutMix):
         if inpt is params["labels"]:
             return self._mixup_label(inpt, lam=lam)
-        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
+        elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
             self._check_image_or_video(inpt, batch_size=params["batch_size"])

             output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

-            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = datapoints.wrap(output, like=inpt)
+            if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
+                output = tv_tensors.wrap(output, like=inpt)

             return output
         else:

@@ -309,7 +309,7 @@ class CutMix(_BaseMixUpCutMix):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if inpt is params["labels"]:
             return self._mixup_label(inpt, lam=params["lam_adjusted"])
-        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
+        elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
             self._check_image_or_video(inpt, batch_size=params["batch_size"])

             x1, y1, x2, y2 = params["box"]

@@ -317,8 +317,8 @@ class CutMix(_BaseMixUpCutMix):
             output = inpt.clone()
             output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

-            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-                output = datapoints.wrap(output, like=inpt)
+            if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
+                output = tv_tensors.wrap(output, like=inpt)

             return output
         else:
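The recurring `tv_tensors.wrap(output, like=inpt)` pattern above exists because most tensor operations on a TVTensor subclass return a plain `torch.Tensor`, so transforms must re-wrap their outputs. A minimal sketch of that behavior, assuming the default unwrapping semantics documented for torchvision >= 0.16:

```python
import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(4, 3, 8, 8))  # batched image
lam = 0.3
out = img.roll(1, 0).mul_(1.0 - lam).add_(img.mul(lam))  # the MixUp arithmetic from above
print(type(out))  # plain torch.Tensor: most ops drop the subclass

out = tv_tensors.wrap(out, like=img)  # restore the subclass (and any metadata)
print(type(out))  # torchvision.tv_tensors.Image
```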
torchvision/transforms/v2/_auto_augment.py

@@ -5,7 +5,7 @@ import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms, tv_tensors
 from torchvision.transforms import _functional_tensor as _FT
 from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation

@@ -15,7 +15,7 @@ from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
 from ._utils import _get_fill, _setup_fill_arg, check_type, is_pure_tensor

-ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
+ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, tv_tensors.Image, tv_tensors.Video]


 class _AutoAugmentBase(Transform):

@@ -46,7 +46,7 @@ class _AutoAugmentBase(Transform):
     def _flatten_and_extract_image_or_video(
         self,
         inputs: Any,
-        unsupported_types: Tuple[Type, ...] = (datapoints.BoundingBoxes, datapoints.Mask),
+        unsupported_types: Tuple[Type, ...] = (tv_tensors.BoundingBoxes, tv_tensors.Mask),
     ) -> Tuple[Tuple[List[Any], TreeSpec, int], ImageOrVideo]:
         flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
         needs_transform_list = self._needs_transform_list(flat_inputs)

@@ -56,10 +56,10 @@ class _AutoAugmentBase(Transform):
             if needs_transform and check_type(
                 inpt,
                 (
-                    datapoints.Image,
+                    tv_tensors.Image,
                     PIL.Image.Image,
                     is_pure_tensor,
-                    datapoints.Video,
+                    tv_tensors.Video,
                 ),
             ):
                 image_or_videos.append((idx, inpt))

@@ -590,7 +590,7 @@ class AugMix(_AutoAugmentBase):
         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

         orig_dims = list(image_or_video.shape)
-        expected_ndim = 5 if isinstance(orig_image_or_video, datapoints.Video) else 4
+        expected_ndim = 5 if isinstance(orig_image_or_video, tv_tensors.Video) else 4
         batch = image_or_video.reshape([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
         batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

@@ -627,8 +627,8 @@ class AugMix(_AutoAugmentBase):
             mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
         mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

-        if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
-            mix = datapoints.wrap(mix, like=orig_image_or_video)
+        if isinstance(orig_image_or_video, (tv_tensors.Image, tv_tensors.Video)):
+            mix = tv_tensors.wrap(mix, like=orig_image_or_video)
         elif isinstance(orig_image_or_video, PIL.Image.Image):
             mix = F.to_pil_image(mix)
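For reference, a hypothetical end-to-end use of the renamed types with the AugMix transform touched above; because AugMix re-wraps its output via `tv_tensors.wrap`, the subclass survives the pipeline (a sketch, not part of the diff):

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

augmix = v2.AugMix()  # defaults: severity=3, mixture_width=3
img = tv_tensors.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
out = augmix(img)
print(type(out))  # tv_tensors.Image: the output is re-wrapped internally
```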
torchvision/transforms/v2/_geometry.py

@@ -6,7 +6,7 @@ from typing import Any, Callable, cast, Dict, List, Literal, Optional, Sequence,
 import PIL.Image
 import torch

-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms, tv_tensors
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform

@@ -36,8 +36,8 @@ class RandomHorizontalFlip(_RandomApplyTransform):
     .. v2betastatus:: RandomHorizontalFlip transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -56,8 +56,8 @@ class RandomVerticalFlip(_RandomApplyTransform):
     .. v2betastatus:: RandomVerticalFlip transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -76,8 +76,8 @@ class Resize(Transform):
     .. v2betastatus:: Resize transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -171,8 +171,8 @@ class CenterCrop(Transform):
     .. v2betastatus:: CenterCrop transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -199,8 +199,8 @@ class RandomResizedCrop(Transform):
     .. v2betastatus:: RandomResizedCrop transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -322,8 +322,8 @@ class FiveCrop(Transform):
     .. v2betastatus:: FiveCrop transform

-    If the input is a :class:`torch.Tensor` or a :class:`~torchvision.datapoints.Image` or a
-    :class:`~torchvision.datapoints.Video` it can have arbitrary number of leading batch dimensions.
+    If the input is a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.Image` or a
+    :class:`~torchvision.tv_tensors.Video` it can have arbitrary number of leading batch dimensions.
     For example, the image can have ``[..., C, H, W]`` shape.

     .. Note::

@@ -338,15 +338,15 @@ class FiveCrop(Transform):
     Example:
         >>> class BatchMultiCrop(transforms.Transform):
-        ...     def forward(self, sample: Tuple[Tuple[Union[datapoints.Image, datapoints.Video], ...], int]):
+        ...     def forward(self, sample: Tuple[Tuple[Union[tv_tensors.Image, tv_tensors.Video], ...], int]):
         ...         images_or_videos, labels = sample
         ...         batch_size = len(images_or_videos)
         ...         image_or_video = images_or_videos[0]
-        ...         images_or_videos = datapoints.wrap(torch.stack(images_or_videos), like=image_or_video)
+        ...         images_or_videos = tv_tensors.wrap(torch.stack(images_or_videos), like=image_or_video)
         ...         labels = torch.full((batch_size,), label, device=images_or_videos.device)
         ...         return images_or_videos, labels
         ...
-        >>> image = datapoints.Image(torch.rand(3, 256, 256))
+        >>> image = tv_tensors.Image(torch.rand(3, 256, 256))
        >>> label = 3
        >>> transform = transforms.Compose([transforms.FiveCrop(224), BatchMultiCrop()])
        >>> images, labels = transform(image, label)

@@ -363,10 +363,10 @@ class FiveCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
-        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
-                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
             )
         return super()._call_kernel(functional, inpt, *args, **kwargs)

@@ -374,7 +374,7 @@ class FiveCrop(Transform):
         return self._call_kernel(F.five_crop, inpt, self.size)

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
-        if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
+        if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
             raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

@@ -384,8 +384,8 @@ class TenCrop(Transform):
     .. v2betastatus:: TenCrop transform

-    If the input is a :class:`torch.Tensor` or a :class:`~torchvision.datapoints.Image` or a
-    :class:`~torchvision.datapoints.Video` it can have arbitrary number of leading batch dimensions.
+    If the input is a :class:`torch.Tensor` or a :class:`~torchvision.tv_tensors.Image` or a
+    :class:`~torchvision.tv_tensors.Video` it can have arbitrary number of leading batch dimensions.
     For example, the image can have ``[..., C, H, W]`` shape.

     See :class:`~torchvision.transforms.v2.FiveCrop` for an example.

@@ -410,15 +410,15 @@ class TenCrop(Transform):
         self.vertical_flip = vertical_flip

     def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
-        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+        if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
             warnings.warn(
                 f"{type(self).__name__}() is currently passing through inputs of type "
-                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+                f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
             )
         return super()._call_kernel(functional, inpt, *args, **kwargs)

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
-        if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
+        if has_any(flat_inputs, tv_tensors.BoundingBoxes, tv_tensors.Mask):
             raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

@@ -430,8 +430,8 @@ class Pad(Transform):
     .. v2betastatus:: Pad transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -447,7 +447,7 @@ class Pad(Transform):
         fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
             Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
             Fill value can be also a dictionary mapping data type to the fill value, e.g.
-            ``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
             ``Mask`` will be filled with 0.
         padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
             Default is "constant".

@@ -515,8 +515,8 @@ class RandomZoomOut(_RandomApplyTransform):
         output_width = input_width * r
         output_height = input_height * r

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -524,7 +524,7 @@ class RandomZoomOut(_RandomApplyTransform):
         fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
             Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
             Fill value can be also a dictionary mapping data type to the fill value, e.g.
-            ``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
             ``Mask`` will be filled with 0.
         side_range (sequence of floats, optional): tuple of two floats defines minimum and maximum factors to
             scale the input size.

@@ -574,8 +574,8 @@ class RandomRotation(Transform):
     .. v2betastatus:: RandomRotation transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -596,7 +596,7 @@ class RandomRotation(Transform):
         fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
             Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
             Fill value can be also a dictionary mapping data type to the fill value, e.g.
-            ``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
             ``Mask`` will be filled with 0.

     .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

@@ -648,8 +648,8 @@ class RandomAffine(Transform):
     .. v2betastatus:: RandomAffine transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -676,7 +676,7 @@ class RandomAffine(Transform):
         fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
             Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
             Fill value can be also a dictionary mapping data type to the fill value, e.g.
-            ``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
             ``Mask`` will be filled with 0.
         center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
             Default is the center of the image.

@@ -770,8 +770,8 @@ class RandomCrop(Transform):
     .. v2betastatus:: RandomCrop transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -794,7 +794,7 @@ class RandomCrop(Transform):
         fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
             Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
             Fill value can be also a dictionary mapping data type to the fill value, e.g.
-            ``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
             ``Mask`` will be filled with 0.
         padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
             Default is constant.

@@ -927,8 +927,8 @@ class RandomPerspective(_RandomApplyTransform):
     .. v2betastatus:: RandomPerspective transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -943,7 +943,7 @@ class RandomPerspective(_RandomApplyTransform):
         fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
             Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
             Fill value can be also a dictionary mapping data type to the fill value, e.g.
-            ``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
             ``Mask`` will be filled with 0.
     """

@@ -1014,8 +1014,8 @@ class ElasticTransform(Transform):
     .. v2betastatus:: RandomPerspective transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -1046,7 +1046,7 @@ class ElasticTransform(Transform):
         fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
             Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
             Fill value can be also a dictionary mapping data type to the fill value, e.g.
-            ``fill={datapoints.Image: 127, datapoints.Mask: 0}`` where ``Image`` will be filled with 127 and
+            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
             ``Mask`` will be filled with 0.
     """

@@ -1107,15 +1107,15 @@ class RandomIoUCrop(Transform):
     .. v2betastatus:: RandomIoUCrop transform

-    This transformation requires an image or video data and ``datapoints.BoundingBoxes`` in the input.
+    This transformation requires an image or video data and ``tv_tensors.BoundingBoxes`` in the input.

     .. warning::
         In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
         must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately
         after or later in the transforms pipeline.

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -1152,8 +1152,8 @@ class RandomIoUCrop(Transform):
     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if not (
-            has_all(flat_inputs, datapoints.BoundingBoxes)
-            and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_pure_tensor)
+            has_all(flat_inputs, tv_tensors.BoundingBoxes)
+            and has_any(flat_inputs, PIL.Image.Image, tv_tensors.Image, is_pure_tensor)
         ):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain tensor or PIL images "

@@ -1193,7 +1193,7 @@ class RandomIoUCrop(Transform):
         xyxy_bboxes = F.convert_bounding_box_format(
             bboxes.as_subclass(torch.Tensor),
             bboxes.format,
-            datapoints.BoundingBoxFormat.XYXY,
+            tv_tensors.BoundingBoxFormat.XYXY,
         )
         cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
         cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])

@@ -1221,7 +1221,7 @@ class RandomIoUCrop(Transform):
             F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
         )

-        if isinstance(output, datapoints.BoundingBoxes):
+        if isinstance(output, tv_tensors.BoundingBoxes):
             # We "mark" the invalid boxes as degenreate, and they can be
             # removed by a later call to SanitizeBoundingBoxes()
             output[~params["is_within_crop_area"]] = 0

@@ -1235,8 +1235,8 @@ class ScaleJitter(Transform):
     .. v2betastatus:: ScaleJitter transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -1303,8 +1303,8 @@ class RandomShortestSize(Transform):
     .. v2betastatus:: RandomShortestSize transform

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.

@@ -1384,8 +1384,8 @@ class RandomResize(Transform):
         output_width = size
         output_height = size

-    If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
-    :class:`~torchvision.datapoints.Video`, :class:`~torchvision.datapoints.BoundingBoxes` etc.)
+    If the input is a :class:`torch.Tensor` or a ``TVTensor`` (e.g. :class:`~torchvision.tv_tensors.Image`,
+    :class:`~torchvision.tv_tensors.Video`, :class:`~torchvision.tv_tensors.BoundingBoxes` etc.)
     it can have arbitrary number of leading batch dimensions. For example,
     the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
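The `fill` docstrings above all advertise the same dict form keyed by tv_tensor type. A minimal sketch of what that looks like for `Pad` (the fill values are illustrative):

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

pad = v2.Pad(padding=4, fill={tv_tensors.Image: 127, tv_tensors.Mask: 0})
img = tv_tensors.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
mask = tv_tensors.Mask(torch.zeros(32, 32, dtype=torch.uint8))
padded_img, padded_mask = pad(img, mask)  # image border -> 127, mask border -> 0
print(padded_img.shape, padded_mask.shape)  # (3, 40, 40), (40, 40)
```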
torchvision/transforms/v2/_meta.py

 from typing import Any, Dict, Union

-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform

@@ -10,20 +10,20 @@ class ConvertBoundingBoxFormat(Transform):
     .. v2betastatus:: ConvertBoundingBoxFormat transform

     Args:
-        format (str or datapoints.BoundingBoxFormat): output bounding box format.
-            Possible values are defined by :class:`~torchvision.datapoints.BoundingBoxFormat` and
+        format (str or tv_tensors.BoundingBoxFormat): output bounding box format.
+            Possible values are defined by :class:`~torchvision.tv_tensors.BoundingBoxFormat` and
             string values match the enums, e.g. "XYXY" or "XYWH" etc.
     """

-    _transformed_types = (datapoints.BoundingBoxes,)
+    _transformed_types = (tv_tensors.BoundingBoxes,)

-    def __init__(self, format: Union[str, datapoints.BoundingBoxFormat]) -> None:
+    def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None:
         super().__init__()
         if isinstance(format, str):
-            format = datapoints.BoundingBoxFormat[format]
+            format = tv_tensors.BoundingBoxFormat[format]
         self.format = format

-    def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
+    def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
         return F.convert_bounding_box_format(inpt, new_format=self.format)  # type: ignore[return-value]

@@ -36,7 +36,7 @@ class ClampBoundingBoxes(Transform):
     """

-    _transformed_types = (datapoints.BoundingBoxes,)
+    _transformed_types = (tv_tensors.BoundingBoxes,)

-    def _transform(self, inpt: datapoints.BoundingBoxes, params: Dict[str, Any]) -> datapoints.BoundingBoxes:
+    def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
         return F.clamp_bounding_boxes(inpt)  # type: ignore[return-value]
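Per the updated docstring, `format` accepts either a string or a `tv_tensors.BoundingBoxFormat` enum value. A small sketch of the renamed API:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10, 10, 20, 30]]), format="XYXY", canvas_size=(64, 64)
)
convert = v2.ConvertBoundingBoxFormat(format=tv_tensors.BoundingBoxFormat.XYWH)
print(convert(boxes))  # [[10, 10, 10, 20]]: x, y, width, height
```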
torchvision/transforms/v2/_misc.py

@@ -6,7 +6,7 @@ import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms, tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform

 from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor

@@ -74,7 +74,7 @@ class LinearTransformation(Transform):
     _v1_transform_cls = _transforms.LinearTransformation

-    _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)
+    _transformed_types = (is_pure_tensor, tv_tensors.Image, tv_tensors.Video)

     def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
         super().__init__()

@@ -129,8 +129,8 @@ class LinearTransformation(Transform):
         output = torch.mm(flat_inpt, transformation_matrix)
         output = output.reshape(shape)

-        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = datapoints.wrap(output, like=inpt)
+        if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
+            output = tv_tensors.wrap(output, like=inpt)
         return output

@@ -227,12 +227,12 @@ class ToDtype(Transform):
     ``ToDtype(dtype, scale=True)`` is the recommended replacement for ``ConvertImageDtype(dtype)``.

     Args:
-        dtype (``torch.dtype`` or dict of ``Datapoint`` -> ``torch.dtype``): The dtype to convert to.
+        dtype (``torch.dtype`` or dict of ``TVTensor`` -> ``torch.dtype``): The dtype to convert to.
             If a ``torch.dtype`` is passed, e.g. ``torch.float32``, only images and videos will be converted
             to that dtype: this is for compatibility with :class:`~torchvision.transforms.v2.ConvertImageDtype`.
-            A dict can be passed to specify per-datapoint conversions, e.g.
-            ``dtype={datapoints.Image: torch.float32, datapoints.Mask: torch.int64, "others":None}``. The "others"
-            key can be used as a catch-all for any other datapoint type, and ``None`` means no conversion.
+            A dict can be passed to specify per-tv_tensor conversions, e.g.
+            ``dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others":None}``. The "others"
+            key can be used as a catch-all for any other tv_tensor type, and ``None`` means no conversion.
         scale (bool, optional): Whether to scale the values for images or videos. See :ref:`range_and_dtype`.
             Default: ``False``.
     """

@@ -250,12 +250,12 @@ class ToDtype(Transform):
         if (
             isinstance(dtype, dict)
             and torch.Tensor in dtype
-            and any(cls in dtype for cls in [datapoints.Image, datapoints.Video])
+            and any(cls in dtype for cls in [tv_tensors.Image, tv_tensors.Video])
         ):
             warnings.warn(
-                "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+                "Got `dtype` values for `torch.Tensor` and either `tv_tensors.Image` or `tv_tensors.Video`. "
                 "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
-                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+                "in case a `tv_tensors.Image` or `tv_tensors.Video` is present in the input."
             )
         self.dtype = dtype
         self.scale = scale

@@ -264,7 +264,7 @@ class ToDtype(Transform):
         if isinstance(self.dtype, torch.dtype):
             # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
             # is a simple torch.dtype
-            if not is_pure_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)):
+            if not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
                 return inpt

             dtype: Optional[torch.dtype] = self.dtype

@@ -278,10 +278,10 @@ class ToDtype(Transform):
                 "If you only need to convert the dtype of images or videos, you can just pass e.g. dtype=torch.float32. "
                 "If you're passing a dict as dtype, "
                 'you can use "others" as a catch-all key '
-                'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
+                'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
             )

-        supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video))
+        supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
         if dtype is None:
             if self.scale and supports_scaling:
                 warnings.warn(

@@ -389,10 +389,10 @@ class SanitizeBoundingBoxes(Transform):
             )

         boxes = cast(
-            datapoints.BoundingBoxes,
+            tv_tensors.BoundingBoxes,
             F.convert_bounding_box_format(
                 boxes,
-                new_format=datapoints.BoundingBoxFormat.XYXY,
+                new_format=tv_tensors.BoundingBoxFormat.XYXY,
             ),
         )
         ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]

@@ -415,7 +415,7 @@ class SanitizeBoundingBoxes(Transform):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         is_label = inpt is not None and inpt is params["labels"]
-        is_bounding_boxes_or_mask = isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask))
+        is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))

         if not (is_label or is_bounding_boxes_or_mask):
             return inpt

@@ -425,4 +425,4 @@ class SanitizeBoundingBoxes(Transform):
         if is_label:
             return output

-        return datapoints.wrap(output, like=inpt)
+        return tv_tensors.wrap(output, like=inpt)
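The `ToDtype` docstring above spells out the per-tv_tensor dict form; a runnable sketch of exactly that example:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

to_dtype = v2.ToDtype(
    dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others": None},
    scale=True,  # scaling applies to images and videos only
)
img = tv_tensors.Image(torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8))
mask = tv_tensors.Mask(torch.randint(0, 2, (16, 16), dtype=torch.uint8))
img, mask = to_dtype(img, mask)
print(img.dtype, mask.dtype)  # torch.float32 (values scaled to [0, 1]), torch.int64
```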
torchvision/transforms/v2/_transform.py

@@ -7,7 +7,7 @@ import PIL.Image
 import torch
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
 from torchvision.utils import _log_api_usage_once

@@ -56,8 +56,8 @@ class Transform(nn.Module):
     def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
         # Below is a heuristic on how to deal with pure tensor inputs:
-        # 1. Pure tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
-        #    (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
+        # 1. Pure tensors, i.e. tensors that are not a tv_tensor, are passed through if there is an explicit image
+        #    (`tv_tensors.Image` or `PIL.Image.Image`) or video (`tv_tensors.Video`) in the sample.
         # 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is
         #    transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
         #    of `tree_flatten`, which recurses depth-first through the input.

@@ -72,7 +72,7 @@ class Transform(nn.Module):
         # However, this case wasn't supported by transforms v1 either, so there is no BC concern.

         needs_transform_list = []
-        transform_pure_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
+        transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)
         for inpt in flat_inputs:
             needs_transform = True
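The comment rewritten above describes the pure-tensor heuristic; a sketch of its observable behavior, assuming standard v2 transform semantics:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

flip = v2.RandomHorizontalFlip(p=1.0)
plain = torch.rand(3, 8, 8)
img = tv_tensors.Image(torch.rand(3, 8, 8))

# An explicit tv_tensors.Image is present, so the plain tensor is passed through untouched.
out_plain, out_img = flip(plain, img)
print(torch.equal(out_plain, plain))  # True: only `img` was flipped
```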
torchvision/transforms/v2/_type_conversion.py

@@ -4,7 +4,7 @@ import numpy as np
 import PIL.Image
 import torch

-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform
 from torchvision.transforms.v2._utils import is_pure_tensor

@@ -27,7 +27,7 @@ class PILToTensor(Transform):
 class ToImage(Transform):
-    """[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image`
+    """[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.tv_tensors.Image`
     ; this does not scale values.

     .. v2betastatus:: ToImage transform

@@ -39,7 +39,7 @@ class ToImage(Transform):
     def _transform(
         self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
-    ) -> datapoints.Image:
+    ) -> tv_tensors.Image:
         return F.to_image(inpt)

@@ -66,7 +66,7 @@ class ToPILImage(Transform):
     .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
     """

-    _transformed_types = (is_pure_tensor, datapoints.Image, np.ndarray)
+    _transformed_types = (is_pure_tensor, tv_tensors.Image, np.ndarray)

     def __init__(self, mode: Optional[str] = None) -> None:
         super().__init__()

@@ -79,14 +79,14 @@ class ToPILImage(Transform):
 class ToPureTensor(Transform):
-    """[BETA] Convert all datapoints to pure tensors, removing associated metadata (if any).
+    """[BETA] Convert all tv_tensors to pure tensors, removing associated metadata (if any).

     .. v2betastatus:: ToPureTensor transform

     This doesn't scale or change the values, only the type.
     """

-    _transformed_types = (datapoints.Datapoint,)
+    _transformed_types = (tv_tensors.TVTensor,)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
         return inpt.as_subclass(torch.Tensor)
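The two conversion transforms touched here bracket a v2 pipeline: `ToImage` wraps raw input into a `tv_tensors.Image`, and `ToPureTensor` strips the subclass at the end. A minimal sketch:

```python
import numpy as np
from torchvision import tv_tensors
from torchvision.transforms import v2

img = v2.ToImage()(np.zeros((16, 16, 3), dtype=np.uint8))  # HWC ndarray -> CHW Image
print(type(img), tuple(img.shape))  # tv_tensors.Image, (3, 16, 16)

plain = v2.ToPureTensor()(img)  # drop the TVTensor subclass, keep the data
print(type(plain))  # torch.Tensor
```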
torchvision/transforms/v2/_utils.py

@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
 import PIL.Image
 import torch

-from torchvision import datapoints
+from torchvision import tv_tensors

 from torchvision._utils import sequence_to_str

@@ -149,10 +149,10 @@ def _parse_labels_getter(
         raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")


-def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
+def get_bounding_boxes(flat_inputs: List[Any]) -> tv_tensors.BoundingBoxes:
     # This assumes there is only one bbox per sample as per the general convention
     try:
-        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
+        return next(inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.BoundingBoxes))
     except StopIteration:
         raise ValueError("No bounding boxes were found in the sample")

@@ -161,7 +161,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
-        if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
+        if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")

@@ -179,11 +179,11 @@ def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
             inpt,
             (
                 is_pure_tensor,
-                datapoints.Image,
+                tv_tensors.Image,
                 PIL.Image.Image,
-                datapoints.Video,
-                datapoints.Mask,
-                datapoints.BoundingBoxes,
+                tv_tensors.Video,
+                tv_tensors.Mask,
+                tv_tensors.BoundingBoxes,
             ),
         )
     }
torchvision/transforms/v2/functional/_augment.py

 import PIL.Image

 import torch
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once

@@ -28,7 +28,7 @@ def erase(
 @_register_kernel_internal(erase, torch.Tensor)
-@_register_kernel_internal(erase, datapoints.Image)
+@_register_kernel_internal(erase, tv_tensors.Image)
 def erase_image(
     image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> torch.Tensor:

@@ -48,7 +48,7 @@ def _erase_image_pil(
     return to_pil_image(output, mode=image.mode)


-@_register_kernel_internal(erase, datapoints.Video)
+@_register_kernel_internal(erase, tv_tensors.Video)
 def erase_video(
     video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> torch.Tensor:
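The `_register_kernel_internal` decorators above register per-type kernels; the public functional then dispatches on the input's tv_tensor type and re-wraps the result. A sketch of the user-visible effect:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

img = tv_tensors.Image(torch.zeros(3, 8, 8))
patch = torch.ones(3, 4, 4)
out = F.erase(img, i=2, j=2, h=4, w=4, v=patch)  # dispatches to the Image kernel
print(type(out))  # tv_tensors.Image: dispatch re-wraps the kernel's output
```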
torchvision/transforms/v2/functional/_color.py
View file @
d5f4cc38
...
...
@@ -3,7 +3,7 @@ from typing import List
import
PIL.Image
import
torch
from
torch.nn.functional
import
conv2d
from
torchvision
import
datapoint
s
from
torchvision
import
tv_tensor
s
from
torchvision.transforms
import
_functional_pil
as
_FP
from
torchvision.transforms._functional_tensor
import
_max_value
...
...
@@ -47,7 +47,7 @@ def _rgb_to_grayscale_image(
@
_register_kernel_internal
(
rgb_to_grayscale
,
torch
.
Tensor
)
@
_register_kernel_internal
(
rgb_to_grayscale
,
datapoint
s
.
Image
)
@
_register_kernel_internal
(
rgb_to_grayscale
,
tv_tensor
s
.
Image
)
def
rgb_to_grayscale_image
(
image
:
torch
.
Tensor
,
num_output_channels
:
int
=
1
)
->
torch
.
Tensor
:
if
num_output_channels
not
in
(
1
,
3
):
raise
ValueError
(
f
"num_output_channels must be 1 or 3, got
{
num_output_channels
}
."
)
...
...
@@ -82,7 +82,7 @@ def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Ten
@
_register_kernel_internal
(
adjust_brightness
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_brightness
,
datapoint
s
.
Image
)
@
_register_kernel_internal
(
adjust_brightness
,
tv_tensor
s
.
Image
)
def
adjust_brightness_image
(
image
:
torch
.
Tensor
,
brightness_factor
:
float
)
->
torch
.
Tensor
:
if
brightness_factor
<
0
:
raise
ValueError
(
f
"brightness_factor (
{
brightness_factor
}
) is not non-negative."
)
...
...
@@ -102,7 +102,7 @@ def _adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: floa
return
_FP
.
adjust_brightness
(
image
,
brightness_factor
=
brightness_factor
)
@
_register_kernel_internal
(
adjust_brightness
,
datapoint
s
.
Video
)
@
_register_kernel_internal
(
adjust_brightness
,
tv_tensor
s
.
Video
)
def
adjust_brightness_video
(
video
:
torch
.
Tensor
,
brightness_factor
:
float
)
->
torch
.
Tensor
:
return
adjust_brightness_image
(
video
,
brightness_factor
=
brightness_factor
)
...
...
@@ -119,7 +119,7 @@ def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Ten
@
_register_kernel_internal
(
adjust_saturation
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_saturation
,
datapoint
s
.
Image
)
@
_register_kernel_internal
(
adjust_saturation
,
tv_tensor
s
.
Image
)
def
adjust_saturation_image
(
image
:
torch
.
Tensor
,
saturation_factor
:
float
)
->
torch
.
Tensor
:
if
saturation_factor
<
0
:
raise
ValueError
(
f
"saturation_factor (
{
saturation_factor
}
) is not non-negative."
)
...
...
@@ -141,7 +141,7 @@ def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> to
_adjust_saturation_image_pil
=
_register_kernel_internal
(
adjust_saturation
,
PIL
.
Image
.
Image
)(
_FP
.
adjust_saturation
)
@
_register_kernel_internal
(
adjust_saturation
,
datapoint
s
.
Video
)
@
_register_kernel_internal
(
adjust_saturation
,
tv_tensor
s
.
Video
)
def
adjust_saturation_video
(
video
:
torch
.
Tensor
,
saturation_factor
:
float
)
->
torch
.
Tensor
:
return
adjust_saturation_image
(
video
,
saturation_factor
=
saturation_factor
)
...
...
@@ -158,7 +158,7 @@ def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
@
_register_kernel_internal
(
adjust_contrast
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_contrast
,
datapoint
s
.
Image
)
@
_register_kernel_internal
(
adjust_contrast
,
tv_tensor
s
.
Image
)
def
adjust_contrast_image
(
image
:
torch
.
Tensor
,
contrast_factor
:
float
)
->
torch
.
Tensor
:
if
contrast_factor
<
0
:
raise
ValueError
(
f
"contrast_factor (
{
contrast_factor
}
) is not non-negative."
)
...
...
@@ -180,7 +180,7 @@ def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.
_adjust_contrast_image_pil
=
_register_kernel_internal
(
adjust_contrast
,
PIL
.
Image
.
Image
)(
_FP
.
adjust_contrast
)
@
_register_kernel_internal
(
adjust_contrast
,
datapoint
s
.
Video
)
@
_register_kernel_internal
(
adjust_contrast
,
tv_tensor
s
.
Video
)
def
adjust_contrast_video
(
video
:
torch
.
Tensor
,
contrast_factor
:
float
)
->
torch
.
Tensor
:
return
adjust_contrast_image
(
video
,
contrast_factor
=
contrast_factor
)
...
...
@@ -197,7 +197,7 @@ def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tenso
@
_register_kernel_internal
(
adjust_sharpness
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_sharpness
,
datapoint
s
.
Image
)
@
_register_kernel_internal
(
adjust_sharpness
,
tv_tensor
s
.
Image
)
def
adjust_sharpness_image
(
image
:
torch
.
Tensor
,
sharpness_factor
:
float
)
->
torch
.
Tensor
:
num_channels
,
height
,
width
=
image
.
shape
[
-
3
:]
if
num_channels
not
in
(
1
,
3
):
...
...
@@ -253,7 +253,7 @@ def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torc
_adjust_sharpness_image_pil
=
_register_kernel_internal
(
adjust_sharpness
,
PIL
.
Image
.
Image
)(
_FP
.
adjust_sharpness
)
@
_register_kernel_internal
(
adjust_sharpness
,
datapoint
s
.
Video
)
@
_register_kernel_internal
(
adjust_sharpness
,
tv_tensor
s
.
Video
)
def
adjust_sharpness_video
(
video
:
torch
.
Tensor
,
sharpness_factor
:
float
)
->
torch
.
Tensor
:
return
adjust_sharpness_image
(
video
,
sharpness_factor
=
sharpness_factor
)
...
...
@@ -340,7 +340,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
@
_register_kernel_internal
(
adjust_hue
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_hue
,
datapoint
s
.
Image
)
@
_register_kernel_internal
(
adjust_hue
,
tv_tensor
s
.
Image
)
def
adjust_hue_image
(
image
:
torch
.
Tensor
,
hue_factor
:
float
)
->
torch
.
Tensor
:
if
not
(
-
0.5
<=
hue_factor
<=
0.5
):
raise
ValueError
(
f
"hue_factor (
{
hue_factor
}
) is not in [-0.5, 0.5]."
)
...
...
@@ -371,7 +371,7 @@ def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
_adjust_hue_image_pil
=
_register_kernel_internal
(
adjust_hue
,
PIL
.
Image
.
Image
)(
_FP
.
adjust_hue
)
@
_register_kernel_internal
(
adjust_hue
,
datapoint
s
.
Video
)
@
_register_kernel_internal
(
adjust_hue
,
tv_tensor
s
.
Video
)
def
adjust_hue_video
(
video
:
torch
.
Tensor
,
hue_factor
:
float
)
->
torch
.
Tensor
:
return
adjust_hue_image
(
video
,
hue_factor
=
hue_factor
)
...
...
@@ -388,7 +388,7 @@ def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Ten
@
_register_kernel_internal
(
adjust_gamma
,
torch
.
Tensor
)
@
_register_kernel_internal
(
adjust_gamma
,
datapoint
s
.
Image
)
@
_register_kernel_internal
(
adjust_gamma
,
tv_tensor
s
.
Image
)
def
adjust_gamma_image
(
image
:
torch
.
Tensor
,
gamma
:
float
,
gain
:
float
=
1.0
)
->
torch
.
Tensor
:
if
gamma
<
0
:
raise
ValueError
(
"Gamma should be a non-negative real number"
)
...
...
@@ -411,7 +411,7 @@ def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) ->
_adjust_gamma_image_pil
=
_register_kernel_internal
(
adjust_gamma
,
PIL
.
Image
.
Image
)(
_FP
.
adjust_gamma
)
@
_register_kernel_internal
(
adjust_gamma
,
datapoint
s
.
Video
)
@
_register_kernel_internal
(
adjust_gamma
,
tv_tensor
s
.
Video
)
def
adjust_gamma_video
(
video
:
torch
.
Tensor
,
gamma
:
float
,
gain: float = 1) -> torch.Tensor:
    return adjust_gamma_image(video, gamma=gamma, gain=gain)

@@ -428,7 +428,7 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
 @_register_kernel_internal(posterize, torch.Tensor)
-@_register_kernel_internal(posterize, datapoints.Image)
+@_register_kernel_internal(posterize, tv_tensors.Image)
 def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
     if image.is_floating_point():
         levels = 1 << bits

@@ -445,7 +445,7 @@ def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
 _posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize)

-@_register_kernel_internal(posterize, datapoints.Video)
+@_register_kernel_internal(posterize, tv_tensors.Video)
 def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
     return posterize_image(video, bits=bits)

@@ -462,7 +462,7 @@ def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
 @_register_kernel_internal(solarize, torch.Tensor)
-@_register_kernel_internal(solarize, datapoints.Image)
+@_register_kernel_internal(solarize, tv_tensors.Image)
 def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
     if threshold > _max_value(image.dtype):
         raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")

@@ -473,7 +473,7 @@ def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor:
 _solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize)

-@_register_kernel_internal(solarize, datapoints.Video)
+@_register_kernel_internal(solarize, tv_tensors.Video)
 def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
     return solarize_image(video, threshold=threshold)

@@ -490,7 +490,7 @@ def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(autocontrast, torch.Tensor)
-@_register_kernel_internal(autocontrast, datapoints.Image)
+@_register_kernel_internal(autocontrast, tv_tensors.Image)
 def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
     c = image.shape[-3]
     if c not in [1, 3]:

@@ -523,7 +523,7 @@ def autocontrast_image(image: torch.Tensor) -> torch.Tensor:
 _autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast)

-@_register_kernel_internal(autocontrast, datapoints.Video)
+@_register_kernel_internal(autocontrast, tv_tensors.Video)
 def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
     return autocontrast_image(video)

@@ -540,7 +540,7 @@ def equalize(inpt: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(equalize, torch.Tensor)
-@_register_kernel_internal(equalize, datapoints.Image)
+@_register_kernel_internal(equalize, tv_tensors.Image)
 def equalize_image(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image

@@ -613,7 +613,7 @@ def equalize_image(image: torch.Tensor) -> torch.Tensor:
 _equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize)

-@_register_kernel_internal(equalize, datapoints.Video)
+@_register_kernel_internal(equalize, tv_tensors.Video)
 def equalize_video(video: torch.Tensor) -> torch.Tensor:
     return equalize_image(video)

@@ -630,7 +630,7 @@ def invert(inpt: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(invert, torch.Tensor)
-@_register_kernel_internal(invert, datapoints.Image)
+@_register_kernel_internal(invert, tv_tensors.Image)
 def invert_image(image: torch.Tensor) -> torch.Tensor:
     if image.is_floating_point():
         return 1.0 - image

@@ -644,7 +644,7 @@ def invert_image(image: torch.Tensor) -> torch.Tensor:
 _invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert)

-@_register_kernel_internal(invert, datapoints.Video)
+@_register_kernel_internal(invert, tv_tensors.Video)
 def invert_video(video: torch.Tensor) -> torch.Tensor:
     return invert_image(video)

@@ -653,7 +653,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     """Permute the channels of the input according to the given permutation.

     This function supports plain :class:`~torch.Tensor`'s, :class:`PIL.Image.Image`'s, and
-    :class:`torchvision.datapoints.Image` and :class:`torchvision.datapoints.Video`.
+    :class:`torchvision.tv_tensors.Image` and :class:`torchvision.tv_tensors.Video`.

     Example:
         >>> rgb_image = torch.rand(3, 256, 256)

@@ -681,7 +681,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
 @_register_kernel_internal(permute_channels, torch.Tensor)
-@_register_kernel_internal(permute_channels, datapoints.Image)
+@_register_kernel_internal(permute_channels, tv_tensors.Image)
 def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     shape = image.shape
     num_channels, height, width = shape[-3:]

@@ -704,6 +704,6 @@ def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image.Image:
     return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation))

-@_register_kernel_internal(permute_channels, datapoints.Video)
+@_register_kernel_internal(permute_channels, tv_tensors.Video)
 def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     return permute_channels_image(video, permutation=permutation)
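
The double registration above is what gives each functional its type-based dispatch: the same entry point serves plain tensors and TVTensor subclasses. A minimal usage sketch, assuming a torchvision build that ships the renamed `tv_tensors` namespace:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

img = tv_tensors.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
out = F.posterize(img, bits=3)   # dispatches to posterize_image via the kernel registry
print(type(out).__name__)        # Image -- the subclass is preserved by the wrapper

plain = img.as_subclass(torch.Tensor)
print(type(F.posterize(plain, bits=3)).__name__)  # Tensor -- plain tensors stay plain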
torchvision/transforms/v2/functional/_geometry.py
View file @ d5f4cc38

@@ -7,7 +7,7 @@ import PIL.Image
 import torch
 from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms import _functional_pil as _FP
 from torchvision.transforms._functional_tensor import _pad_symmetric
 from torchvision.transforms.functional import (

@@ -51,7 +51,7 @@ def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(horizontal_flip, torch.Tensor)
-@_register_kernel_internal(horizontal_flip, datapoints.Image)
+@_register_kernel_internal(horizontal_flip, tv_tensors.Image)
 def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-1)

@@ -61,37 +61,37 @@ def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
     return _FP.hflip(image)

-@_register_kernel_internal(horizontal_flip, datapoints.Mask)
+@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
 def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image(mask)

 def horizontal_flip_bounding_boxes(
-    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
+    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
 ) -> torch.Tensor:
     shape = bounding_boxes.shape
     bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

-    if format == datapoints.BoundingBoxFormat.XYXY:
+    if format == tv_tensors.BoundingBoxFormat.XYXY:
         bounding_boxes[:, [2, 0]] = bounding_boxes[:, [0, 2]].sub_(canvas_size[1]).neg_()
-    elif format == datapoints.BoundingBoxFormat.XYWH:
+    elif format == tv_tensors.BoundingBoxFormat.XYWH:
         bounding_boxes[:, 0].add_(bounding_boxes[:, 2]).sub_(canvas_size[1]).neg_()
-    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
+    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH:
         bounding_boxes[:, 0].sub_(canvas_size[1]).neg_()

     return bounding_boxes.reshape(shape)

-@_register_kernel_internal(horizontal_flip, datapoints.BoundingBoxes, datapoint_wrapper=False)
-def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes:
+@_register_kernel_internal(horizontal_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
+def _horizontal_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
     output = horizontal_flip_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
     )
-    return datapoints.wrap(output, like=inpt)
+    return tv_tensors.wrap(output, like=inpt)

-@_register_kernel_internal(horizontal_flip, datapoints.Video)
+@_register_kernel_internal(horizontal_flip, tv_tensors.Video)
 def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image(video)

@@ -108,7 +108,7 @@ def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
 @_register_kernel_internal(vertical_flip, torch.Tensor)
-@_register_kernel_internal(vertical_flip, datapoints.Image)
+@_register_kernel_internal(vertical_flip, tv_tensors.Image)
 def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
     return image.flip(-2)

@@ -118,37 +118,37 @@ def _vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
     return _FP.vflip(image)

-@_register_kernel_internal(vertical_flip, datapoints.Mask)
+@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
 def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
     return vertical_flip_image(mask)

 def vertical_flip_bounding_boxes(
-    bounding_boxes: torch.Tensor, format: datapoints.BoundingBoxFormat, canvas_size: Tuple[int, int]
+    bounding_boxes: torch.Tensor, format: tv_tensors.BoundingBoxFormat, canvas_size: Tuple[int, int]
 ) -> torch.Tensor:
     shape = bounding_boxes.shape
     bounding_boxes = bounding_boxes.clone().reshape(-1, 4)

-    if format == datapoints.BoundingBoxFormat.XYXY:
+    if format == tv_tensors.BoundingBoxFormat.XYXY:
         bounding_boxes[:, [1, 3]] = bounding_boxes[:, [3, 1]].sub_(canvas_size[0]).neg_()
-    elif format == datapoints.BoundingBoxFormat.XYWH:
+    elif format == tv_tensors.BoundingBoxFormat.XYWH:
         bounding_boxes[:, 1].add_(bounding_boxes[:, 3]).sub_(canvas_size[0]).neg_()
-    else:  # format == datapoints.BoundingBoxFormat.CXCYWH:
+    else:  # format == tv_tensors.BoundingBoxFormat.CXCYWH:
         bounding_boxes[:, 1].sub_(canvas_size[0]).neg_()

     return bounding_boxes.reshape(shape)

-@_register_kernel_internal(vertical_flip, datapoints.BoundingBoxes, datapoint_wrapper=False)
-def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes:
+@_register_kernel_internal(vertical_flip, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
+def _vertical_flip_bounding_boxes_dispatch(inpt: tv_tensors.BoundingBoxes) -> tv_tensors.BoundingBoxes:
     output = vertical_flip_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
     )
-    return datapoints.wrap(output, like=inpt)
+    return tv_tensors.wrap(output, like=inpt)

-@_register_kernel_internal(vertical_flip, datapoints.Video)
+@_register_kernel_internal(vertical_flip, tv_tensors.Video)
 def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
     return vertical_flip_image(video)

@@ -190,7 +190,7 @@ def resize(
 @_register_kernel_internal(resize, torch.Tensor)
-@_register_kernel_internal(resize, datapoints.Image)
+@_register_kernel_internal(resize, tv_tensors.Image)
 def resize_image(
     image: torch.Tensor,
     size: List[int],

@@ -319,12 +319,12 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
     return output

-@_register_kernel_internal(resize, datapoints.Mask, datapoint_wrapper=False)
+@_register_kernel_internal(resize, tv_tensors.Mask, tv_tensor_wrapper=False)
 def _resize_mask_dispatch(
-    inpt: datapoints.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
-) -> datapoints.Mask:
+    inpt: tv_tensors.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
+) -> tv_tensors.Mask:
     output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size)
-    return datapoints.wrap(output, like=inpt)
+    return tv_tensors.wrap(output, like=inpt)

 def resize_bounding_boxes(

@@ -345,17 +345,17 @@ def resize_bounding_boxes(
     )

-@_register_kernel_internal(resize, datapoints.BoundingBoxes, datapoint_wrapper=False)
+@_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
 def _resize_bounding_boxes_dispatch(
-    inpt: datapoints.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
-) -> datapoints.BoundingBoxes:
+    inpt: tv_tensors.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
+) -> tv_tensors.BoundingBoxes:
     output, canvas_size = resize_bounding_boxes(inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size)
-    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
+    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

-@_register_kernel_internal(resize, datapoints.Video)
+@_register_kernel_internal(resize, tv_tensors.Video)
 def resize_video(
     video: torch.Tensor,
     size: List[int],

@@ -651,7 +651,7 @@ def _affine_grid(
 @_register_kernel_internal(affine, torch.Tensor)
-@_register_kernel_internal(affine, datapoints.Image)
+@_register_kernel_internal(affine, tv_tensors.Image)
 def affine_image(
     image: torch.Tensor,
     angle: Union[int, float],

@@ -730,7 +730,7 @@ def _affine_image_pil(
 def _affine_bounding_boxes_with_expand(
     bounding_boxes: torch.Tensor,
-    format: datapoints.BoundingBoxFormat,
+    format: tv_tensors.BoundingBoxFormat,
     canvas_size: Tuple[int, int],
     angle: Union[int, float],
     translate: List[float],

@@ -749,7 +749,7 @@ def _affine_bounding_boxes_with_expand(
     device = bounding_boxes.device
     bounding_boxes = (
         convert_bounding_box_format(
-            bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+            bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
         )
     ).reshape(-1, 4)

@@ -808,9 +808,9 @@ def _affine_bounding_boxes_with_expand(
         new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
         canvas_size = (new_height, new_width)

-    out_bboxes = clamp_bounding_boxes(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
+    out_bboxes = clamp_bounding_boxes(out_bboxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size)
     out_bboxes = convert_bounding_box_format(
-        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(original_shape)

     out_bboxes = out_bboxes.to(original_dtype)

@@ -819,7 +819,7 @@ def _affine_bounding_boxes_with_expand(
 def affine_bounding_boxes(
     bounding_boxes: torch.Tensor,
-    format: datapoints.BoundingBoxFormat,
+    format: tv_tensors.BoundingBoxFormat,
     canvas_size: Tuple[int, int],
     angle: Union[int, float],
     translate: List[float],

@@ -841,16 +841,16 @@ def affine_bounding_boxes(
     return out_box

-@_register_kernel_internal(affine, datapoints.BoundingBoxes, datapoint_wrapper=False)
+@_register_kernel_internal(affine, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
 def _affine_bounding_boxes_dispatch(
-    inpt: datapoints.BoundingBoxes,
+    inpt: tv_tensors.BoundingBoxes,
     angle: Union[int, float],
     translate: List[float],
     scale: float,
     shear: List[float],
     center: Optional[List[float]] = None,
     **kwargs,
-) -> datapoints.BoundingBoxes:
+) -> tv_tensors.BoundingBoxes:
     output = affine_bounding_boxes(
         inpt.as_subclass(torch.Tensor),
         format=inpt.format,

@@ -861,7 +861,7 @@ def _affine_bounding_boxes_dispatch(
         shear=shear,
         center=center,
     )
-    return datapoints.wrap(output, like=inpt)
+    return tv_tensors.wrap(output, like=inpt)

 def affine_mask(

@@ -896,9 +896,9 @@ def affine_mask(
     return output

-@_register_kernel_internal(affine, datapoints.Mask, datapoint_wrapper=False)
+@_register_kernel_internal(affine, tv_tensors.Mask, tv_tensor_wrapper=False)
 def _affine_mask_dispatch(
-    inpt: datapoints.Mask,
+    inpt: tv_tensors.Mask,
     angle: Union[int, float],
     translate: List[float],
     scale: float,

@@ -906,7 +906,7 @@ def _affine_mask_dispatch(
     fill: _FillTypeJIT = None,
     center: Optional[List[float]] = None,
     **kwargs,
-) -> datapoints.Mask:
+) -> tv_tensors.Mask:
     output = affine_mask(
         inpt.as_subclass(torch.Tensor),
         angle=angle,

@@ -916,10 +916,10 @@ def _affine_mask_dispatch(
         fill=fill,
         center=center,
     )
-    return datapoints.wrap(output, like=inpt)
+    return tv_tensors.wrap(output, like=inpt)

-@_register_kernel_internal(affine, datapoints.Video)
+@_register_kernel_internal(affine, tv_tensors.Video)
 def affine_video(
     video: torch.Tensor,
     angle: Union[int, float],

@@ -961,7 +961,7 @@ def rotate(
 @_register_kernel_internal(rotate, torch.Tensor)
-@_register_kernel_internal(rotate, datapoints.Image)
+@_register_kernel_internal(rotate, tv_tensors.Image)
 def rotate_image(
     image: torch.Tensor,
     angle: float,

@@ -1027,7 +1027,7 @@ def _rotate_image_pil(
 def rotate_bounding_boxes(
     bounding_boxes: torch.Tensor,
-    format: datapoints.BoundingBoxFormat,
+    format: tv_tensors.BoundingBoxFormat,
     canvas_size: Tuple[int, int],
     angle: float,
     expand: bool = False,

@@ -1049,10 +1049,10 @@ def rotate_bounding_boxes(
     )

-@_register_kernel_internal(rotate, datapoints.BoundingBoxes, datapoint_wrapper=False)
+@_register_kernel_internal(rotate, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
 def _rotate_bounding_boxes_dispatch(
-    inpt: datapoints.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
-) -> datapoints.BoundingBoxes:
+    inpt: tv_tensors.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
+) -> tv_tensors.BoundingBoxes:
     output, canvas_size = rotate_bounding_boxes(
         inpt.as_subclass(torch.Tensor),
         format=inpt.format,

@@ -1061,7 +1061,7 @@ def _rotate_bounding_boxes_dispatch(
         expand=expand,
         center=center,
     )
-    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
+    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

 def rotate_mask(

@@ -1092,20 +1092,20 @@ def rotate_mask(
     return output

-@_register_kernel_internal(rotate, datapoints.Mask, datapoint_wrapper=False)
+@_register_kernel_internal(rotate, tv_tensors.Mask, tv_tensor_wrapper=False)
 def _rotate_mask_dispatch(
-    inpt: datapoints.Mask,
+    inpt: tv_tensors.Mask,
     angle: float,
     expand: bool = False,
     center: Optional[List[float]] = None,
     fill: _FillTypeJIT = None,
     **kwargs,
-) -> datapoints.Mask:
+) -> tv_tensors.Mask:
     output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
-    return datapoints.wrap(output, like=inpt)
+    return tv_tensors.wrap(output, like=inpt)

-@_register_kernel_internal(rotate, datapoints.Video)
+@_register_kernel_internal(rotate, tv_tensors.Video)
 def rotate_video(
     video: torch.Tensor,
     angle: float,

@@ -1158,7 +1158,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
 @_register_kernel_internal(pad, torch.Tensor)
-@_register_kernel_internal(pad, datapoints.Image)
+@_register_kernel_internal(pad, tv_tensors.Image)
 def pad_image(
     image: torch.Tensor,
     padding: List[int],

@@ -1260,7 +1260,7 @@ def _pad_with_vector_fill(
 _pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad)

-@_register_kernel_internal(pad, datapoints.Mask)
+@_register_kernel_internal(pad, tv_tensors.Mask)
 def pad_mask(
     mask: torch.Tensor,
     padding: List[int],

@@ -1289,7 +1289,7 @@ def pad_mask(
 def pad_bounding_boxes(
     bounding_boxes: torch.Tensor,
-    format: datapoints.BoundingBoxFormat,
+    format: tv_tensors.BoundingBoxFormat,
     canvas_size: Tuple[int, int],
     padding: List[int],
     padding_mode: str = "constant",

@@ -1300,7 +1300,7 @@ def pad_bounding_boxes(
     left, right, top, bottom = _parse_pad_padding(padding)

-    if format == datapoints.BoundingBoxFormat.XYXY:
+    if format == tv_tensors.BoundingBoxFormat.XYXY:
         pad = [left, top, left, top]
     else:
         pad = [left, top, 0, 0]

@@ -1314,10 +1314,10 @@ def pad_bounding_boxes(
     return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size

-@_register_kernel_internal(pad, datapoints.BoundingBoxes, datapoint_wrapper=False)
+@_register_kernel_internal(pad, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
 def _pad_bounding_boxes_dispatch(
-    inpt: datapoints.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
-) -> datapoints.BoundingBoxes:
+    inpt: tv_tensors.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
+) -> tv_tensors.BoundingBoxes:
     output, canvas_size = pad_bounding_boxes(
         inpt.as_subclass(torch.Tensor),
         format=inpt.format,

@@ -1325,10 +1325,10 @@ def _pad_bounding_boxes_dispatch(
         padding=padding,
         padding_mode=padding_mode,
     )
-    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
+    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

-@_register_kernel_internal(pad, datapoints.Video)
+@_register_kernel_internal(pad, tv_tensors.Video)
 def pad_video(
     video: torch.Tensor,
     padding: List[int],

@@ -1350,7 +1350,7 @@ def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
 @_register_kernel_internal(crop, torch.Tensor)
-@_register_kernel_internal(crop, datapoints.Image)
+@_register_kernel_internal(crop, tv_tensors.Image)
 def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     h, w = image.shape[-2:]

@@ -1375,7 +1375,7 @@ _register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil)
 def crop_bounding_boxes(
     bounding_boxes: torch.Tensor,
-    format: datapoints.BoundingBoxFormat,
+    format: tv_tensors.BoundingBoxFormat,
     top: int,
     left: int,
     height: int,

@@ -1383,7 +1383,7 @@ def crop_bounding_boxes(
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     # Crop or implicit pad if left and/or top have negative values:
-    if format == datapoints.BoundingBoxFormat.XYXY:
+    if format == tv_tensors.BoundingBoxFormat.XYXY:
         sub = [left, top, left, top]
     else:
         sub = [left, top, 0, 0]

@@ -1394,17 +1394,17 @@ def crop_bounding_boxes(
     return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size

-@_register_kernel_internal(crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
+@_register_kernel_internal(crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
 def _crop_bounding_boxes_dispatch(
-    inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int
-) -> datapoints.BoundingBoxes:
+    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int
+) -> tv_tensors.BoundingBoxes:
     output, canvas_size = crop_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width
     )
-    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
+    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

-@_register_kernel_internal(crop, datapoints.Mask)
+@_register_kernel_internal(crop, tv_tensors.Mask)
 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)

@@ -1420,7 +1420,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     return output

-@_register_kernel_internal(crop, datapoints.Video)
+@_register_kernel_internal(crop, tv_tensors.Video)
 def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
     return crop_image(video, top, left, height, width)

@@ -1505,7 +1505,7 @@ def _perspective_coefficients(
 @_register_kernel_internal(perspective, torch.Tensor)
-@_register_kernel_internal(perspective, datapoints.Image)
+@_register_kernel_internal(perspective, tv_tensors.Image)
 def perspective_image(
     image: torch.Tensor,
     startpoints: Optional[List[List[int]]],

@@ -1568,7 +1568,7 @@ def _perspective_image_pil(
 def perspective_bounding_boxes(
     bounding_boxes: torch.Tensor,
-    format: datapoints.BoundingBoxFormat,
+    format: tv_tensors.BoundingBoxFormat,
     canvas_size: Tuple[int, int],
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],

@@ -1582,7 +1582,7 @@ def perspective_bounding_boxes(
     original_shape = bounding_boxes.shape
     # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
     bounding_boxes = (
-        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
+        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     dtype = bounding_boxes.dtype if torch.is_floating_point(bounding_boxes) else torch.float32

@@ -1649,25 +1649,25 @@ def perspective_bounding_boxes(
     out_bboxes = clamp_bounding_boxes(
         torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
-        format=datapoints.BoundingBoxFormat.XYXY,
+        format=tv_tensors.BoundingBoxFormat.XYXY,
         canvas_size=canvas_size,
     )

     # out_bboxes should be of shape [N boxes, 4]
     return convert_bounding_box_format(
-        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(original_shape)

-@_register_kernel_internal(perspective, datapoints.BoundingBoxes, datapoint_wrapper=False)
+@_register_kernel_internal(perspective, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
 def _perspective_bounding_boxes_dispatch(
-    inpt: datapoints.BoundingBoxes,
+    inpt: tv_tensors.BoundingBoxes,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     coefficients: Optional[List[float]] = None,
     **kwargs,
-) -> datapoints.BoundingBoxes:
+) -> tv_tensors.BoundingBoxes:
     output = perspective_bounding_boxes(
         inpt.as_subclass(torch.Tensor),
         format=inpt.format,

@@ -1676,7 +1676,7 @@ def _perspective_bounding_boxes_dispatch(
         endpoints=endpoints,
         coefficients=coefficients,
     )
-    return datapoints.wrap(output, like=inpt)
+    return tv_tensors.wrap(output, like=inpt)

 def perspective_mask(

@@ -1702,15 +1702,15 @@ def perspective_mask(
     return output

-@_register_kernel_internal(perspective, datapoints.Mask, datapoint_wrapper=False)
+@_register_kernel_internal(perspective, tv_tensors.Mask, tv_tensor_wrapper=False)
 def _perspective_mask_dispatch(
-    inpt: datapoints.Mask,
+    inpt: tv_tensors.Mask,
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     fill: _FillTypeJIT = None,
     coefficients: Optional[List[float]] = None,
     **kwargs,
-) -> datapoints.Mask:
+) -> tv_tensors.Mask:
     output = perspective_mask(
         inpt.as_subclass(torch.Tensor),
         startpoints=startpoints,

@@ -1718,10 +1718,10 @@ def _perspective_mask_dispatch(
         fill=fill,
         coefficients=coefficients,
     )
-    return datapoints.wrap(output, like=inpt)
+    return tv_tensors.wrap(output, like=inpt)

-@_register_kernel_internal(perspective, datapoints.Video)
+@_register_kernel_internal(perspective, tv_tensors.Video)
 def perspective_video(
     video: torch.Tensor,
     startpoints: Optional[List[List[int]]],

@@ -1755,7 +1755,7 @@ elastic_transform = elastic
 @_register_kernel_internal(elastic, torch.Tensor)
-@_register_kernel_internal(elastic, datapoints.Image)
+@_register_kernel_internal(elastic, tv_tensors.Image)
 def elastic_image(
     image: torch.Tensor,
     displacement: torch.Tensor,

@@ -1841,7 +1841,7 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
 def elastic_bounding_boxes(
     bounding_boxes: torch.Tensor,
-    format: datapoints.BoundingBoxFormat,
+    format: tv_tensors.BoundingBoxFormat,
     canvas_size: Tuple[int, int],
     displacement: torch.Tensor,
 ) -> torch.Tensor:

@@ -1864,7 +1864,7 @@ def elastic_bounding_boxes(
     original_shape = bounding_boxes.shape
     # TODO: first cast to float if bbox is int64 before convert_bounding_box_format
     bounding_boxes = (
-        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
+        convert_bounding_box_format(bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)

@@ -1887,23 +1887,23 @@ def elastic_bounding_boxes(
     out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
     out_bboxes = clamp_bounding_boxes(
         torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_boxes.dtype),
-        format=datapoints.BoundingBoxFormat.XYXY,
+        format=tv_tensors.BoundingBoxFormat.XYXY,
         canvas_size=canvas_size,
     )

     return convert_bounding_box_format(
-        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+        out_bboxes, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format, inplace=True
     ).reshape(original_shape)

-@_register_kernel_internal(elastic, datapoints.BoundingBoxes, datapoint_wrapper=False)
+@_register_kernel_internal(elastic, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
 def _elastic_bounding_boxes_dispatch(
-    inpt: datapoints.BoundingBoxes, displacement: torch.Tensor, **kwargs
-) -> datapoints.BoundingBoxes:
+    inpt: tv_tensors.BoundingBoxes, displacement: torch.Tensor, **kwargs
+) -> tv_tensors.BoundingBoxes:
     output = elastic_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
     )
-    return datapoints.wrap(output, like=inpt)
+    return tv_tensors.wrap(output, like=inpt)

 def elastic_mask(

@@ -1925,15 +1925,15 @@ def elastic_mask(
     return output

-@_register_kernel_internal(elastic, datapoints.Mask, datapoint_wrapper=False)
+@_register_kernel_internal(elastic, tv_tensors.Mask, tv_tensor_wrapper=False)
 def _elastic_mask_dispatch(
-    inpt: datapoints.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
-) -> datapoints.Mask:
+    inpt: tv_tensors.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs
+) -> tv_tensors.Mask:
     output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
-    return datapoints.wrap(output, like=inpt)
+    return tv_tensors.wrap(output, like=inpt)

-@_register_kernel_internal(elastic, datapoints.Video)
+@_register_kernel_internal(elastic, tv_tensors.Video)
 def elastic_video(
     video: torch.Tensor,
     displacement: torch.Tensor,

@@ -1982,7 +1982,7 @@ def _center_crop_compute_crop_anchor(
 @_register_kernel_internal(center_crop, torch.Tensor)
-@_register_kernel_internal(center_crop, datapoints.Image)
+@_register_kernel_internal(center_crop, tv_tensors.Image)
 def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     shape = image.shape

@@ -2021,7 +2021,7 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
 def center_crop_bounding_boxes(
     bounding_boxes: torch.Tensor,
-    format: datapoints.BoundingBoxFormat,
+    format: tv_tensors.BoundingBoxFormat,
     canvas_size: Tuple[int, int],
     output_size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:

@@ -2032,17 +2032,17 @@ def center_crop_bounding_boxes(
     )

-@_register_kernel_internal(center_crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
+@_register_kernel_internal(center_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
 def _center_crop_bounding_boxes_dispatch(
-    inpt: datapoints.BoundingBoxes, output_size: List[int]
-) -> datapoints.BoundingBoxes:
+    inpt: tv_tensors.BoundingBoxes, output_size: List[int]
+) -> tv_tensors.BoundingBoxes:
     output, canvas_size = center_crop_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
     )
-    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
+    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

-@_register_kernel_internal(center_crop, datapoints.Mask)
+@_register_kernel_internal(center_crop, tv_tensors.Mask)
 def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     if mask.ndim < 3:
         mask = mask.unsqueeze(0)

@@ -2058,7 +2058,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     return output

-@_register_kernel_internal(center_crop, datapoints.Video)
+@_register_kernel_internal(center_crop, tv_tensors.Video)
 def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     return center_crop_image(video, output_size)

@@ -2102,7 +2102,7 @@ def resized_crop(
 @_register_kernel_internal(resized_crop, torch.Tensor)
-@_register_kernel_internal(resized_crop, datapoints.Image)
+@_register_kernel_internal(resized_crop, tv_tensors.Image)
 def resized_crop_image(
     image: torch.Tensor,
     top: int,

@@ -2156,7 +2156,7 @@ def _resized_crop_image_pil_dispatch(
 def resized_crop_bounding_boxes(
     bounding_boxes: torch.Tensor,
-    format: datapoints.BoundingBoxFormat,
+    format: tv_tensors.BoundingBoxFormat,
     top: int,
     left: int,
     height: int,

@@ -2167,14 +2167,14 @@ def resized_crop_bounding_boxes(
     return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size)

-@_register_kernel_internal(resized_crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
+@_register_kernel_internal(resized_crop, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
 def _resized_crop_bounding_boxes_dispatch(
-    inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
-) -> datapoints.BoundingBoxes:
+    inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
+) -> tv_tensors.BoundingBoxes:
     output, canvas_size = resized_crop_bounding_boxes(
         inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
     )
-    return datapoints.wrap(output, like=inpt, canvas_size=canvas_size)
+    return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)

 def resized_crop_mask(

@@ -2189,17 +2189,17 @@ def resized_crop_mask(
     return resize_mask(mask, size)

-@_register_kernel_internal(resized_crop, datapoints.Mask, datapoint_wrapper=False)
+@_register_kernel_internal(resized_crop, tv_tensors.Mask, tv_tensor_wrapper=False)
 def _resized_crop_mask_dispatch(
-    inpt: datapoints.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
-) -> datapoints.Mask:
+    inpt: tv_tensors.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
+) -> tv_tensors.Mask:
     output = resized_crop_mask(
         inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
     )
-    return datapoints.wrap(output, like=inpt)
+    return tv_tensors.wrap(output, like=inpt)

-@_register_kernel_internal(resized_crop, datapoints.Video)
+@_register_kernel_internal(resized_crop, tv_tensors.Video)
 def resized_crop_video(
     video: torch.Tensor,
     top: int,

@@ -2243,7 +2243,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
 @_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
-@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Image)
+@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Image)
 def five_crop_image(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

@@ -2281,7 +2281,7 @@ def _five_crop_image_pil(
     return tl, tr, bl, br, center

-@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Video)
+@_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Video)
 def five_crop_video(
     video: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

@@ -2313,7 +2313,7 @@ def ten_crop(
 @_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
-@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Image)
+@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Image)
 def ten_crop_image(
     image: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[

@@ -2367,7 +2367,7 @@ def _ten_crop_image_pil(
     return non_flipped + flipped

-@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Video)
+@_register_five_ten_crop_kernel_internal(ten_crop, tv_tensors.Video)
 def ten_crop_video(
     video: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
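
The bounding-box dispatchers in this file follow one pattern throughout: unwrap with `as_subclass`, run the pure-tensor kernel with the metadata passed explicitly, and re-attach it with `tv_tensors.wrap`. A minimal sketch of the observable effect, under the same renamed-API assumption as above:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

boxes = tv_tensors.BoundingBoxes(
    [[10, 10, 20, 20]],
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(32, 32),  # (height, width)
)
flipped = F.horizontal_flip(boxes)
# x coordinates are mirrored around the canvas width: [[12, 10, 22, 20]];
# format and canvas_size survive because the dispatch calls tv_tensors.wrap(..., like=inpt)
print(flipped.tolist(), flipped.format, flipped.canvas_size)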
torchvision/transforms/v2/functional/_meta.py
View file @ d5f4cc38

@@ -2,9 +2,9 @@ from typing import List, Optional, Tuple
 import PIL.Image
 import torch
-from torchvision import datapoints
-from torchvision.datapoints import BoundingBoxFormat
+from torchvision import tv_tensors
 from torchvision.transforms import _functional_pil as _FP
+from torchvision.tv_tensors import BoundingBoxFormat
 from torchvision.utils import _log_api_usage_once

@@ -22,7 +22,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]:
 @_register_kernel_internal(get_dimensions, torch.Tensor)
-@_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False)
+@_register_kernel_internal(get_dimensions, tv_tensors.Image, tv_tensor_wrapper=False)
 def get_dimensions_image(image: torch.Tensor) -> List[int]:
     chw = list(image.shape[-3:])
     ndims = len(chw)

@@ -38,7 +38,7 @@ def get_dimensions_image(image: torch.Tensor) -> List[int]:
 _get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)

-@_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
+@_register_kernel_internal(get_dimensions, tv_tensors.Video, tv_tensor_wrapper=False)
 def get_dimensions_video(video: torch.Tensor) -> List[int]:
     return get_dimensions_image(video)

@@ -54,7 +54,7 @@ def get_num_channels(inpt: torch.Tensor) -> int:
 @_register_kernel_internal(get_num_channels, torch.Tensor)
-@_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
+@_register_kernel_internal(get_num_channels, tv_tensors.Image, tv_tensor_wrapper=False)
 def get_num_channels_image(image: torch.Tensor) -> int:
     chw = image.shape[-3:]
     ndims = len(chw)

@@ -69,7 +69,7 @@ def get_num_channels_image(image: torch.Tensor) -> int:
 _get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)

-@_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
+@_register_kernel_internal(get_num_channels, tv_tensors.Video, tv_tensor_wrapper=False)
 def get_num_channels_video(video: torch.Tensor) -> int:
     return get_num_channels_image(video)

@@ -90,7 +90,7 @@ def get_size(inpt: torch.Tensor) -> List[int]:
 @_register_kernel_internal(get_size, torch.Tensor)
-@_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
+@_register_kernel_internal(get_size, tv_tensors.Image, tv_tensor_wrapper=False)
 def get_size_image(image: torch.Tensor) -> List[int]:
     hw = list(image.shape[-2:])
     ndims = len(hw)

@@ -106,18 +106,18 @@ def _get_size_image_pil(image: PIL.Image.Image) -> List[int]:
     return [height, width]

-@_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False)
+@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
 def get_size_video(video: torch.Tensor) -> List[int]:
     return get_size_image(video)

-@_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False)
+@_register_kernel_internal(get_size, tv_tensors.Mask, tv_tensor_wrapper=False)
 def get_size_mask(mask: torch.Tensor) -> List[int]:
     return get_size_image(mask)

-@_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False)
-def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
+@_register_kernel_internal(get_size, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
+def get_size_bounding_boxes(bounding_box: tv_tensors.BoundingBoxes) -> List[int]:
     return list(bounding_box.canvas_size)

@@ -132,7 +132,7 @@ def get_num_frames(inpt: torch.Tensor) -> int:
 @_register_kernel_internal(get_num_frames, torch.Tensor)
-@_register_kernel_internal(get_num_frames, datapoints.Video, datapoint_wrapper=False)
+@_register_kernel_internal(get_num_frames, tv_tensors.Video, tv_tensor_wrapper=False)
 def get_num_frames_video(video: torch.Tensor) -> int:
     return video.shape[-4]

@@ -205,7 +205,7 @@ def convert_bounding_box_format(
 ) -> torch.Tensor:
     """[BETA] See :func:`~torchvision.transforms.v2.ConvertBoundingBoxFormat` for details."""
     # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor
-    # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
+    # inputs as well as extract it from `tv_tensors.BoundingBoxes` inputs. However, putting a default value on
     # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
     # default error that would be thrown if `new_format` had no default value.
     if new_format is None:

@@ -218,16 +218,16 @@ def convert_bounding_box_format(
         if old_format is None:
             raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
         return _convert_bounding_box_format(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
-    elif isinstance(inpt, datapoints.BoundingBoxes):
+    elif isinstance(inpt, tv_tensors.BoundingBoxes):
         if old_format is not None:
-            raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
+            raise ValueError("For bounding box tv_tensor inputs, `old_format` must not be passed.")
         output = _convert_bounding_box_format(
             inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace
         )
-        return datapoints.wrap(output, like=inpt, format=new_format)
+        return tv_tensors.wrap(output, like=inpt, format=new_format)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
+            f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
         )

@@ -239,7 +239,7 @@ def _clamp_bounding_boxes(
     in_dtype = bounding_boxes.dtype
     bounding_boxes = bounding_boxes.clone() if bounding_boxes.is_floating_point() else bounding_boxes.float()
     xyxy_boxes = convert_bounding_box_format(
-        bounding_boxes, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+        bounding_boxes, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY, inplace=True
     )
     xyxy_boxes[..., 0::2].clamp_(min=0, max=canvas_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=canvas_size[0])

@@ -263,12 +263,12 @@ def clamp_bounding_boxes(
         if format is None or canvas_size is None:
             raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.")
         return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
-    elif isinstance(inpt, datapoints.BoundingBoxes):
+    elif isinstance(inpt, tv_tensors.BoundingBoxes):
         if format is not None or canvas_size is not None:
-            raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.")
+            raise ValueError("For bounding box tv_tensor inputs, `format` and `canvas_size` must not be passed.")
         output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size)
-        return datapoints.wrap(output, like=inpt)
+        return tv_tensors.wrap(output, like=inpt)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
+            f"Input can either be a plain tensor or a bounding box tv_tensor, but got {type(inpt)} instead."
         )
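
`convert_bounding_box_format` and `clamp_bounding_boxes` are kernel/functional hybrids, as the comments above explain: metadata is passed explicitly for pure tensors and read from the input for `BoundingBoxes`. A sketch of both paths, assuming the renamed `tv_tensors` namespace:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

# Pure tensor path: old_format must be given explicitly.
t = torch.tensor([[10.0, 10.0, 5.0, 5.0]])
xyxy = F.convert_bounding_box_format(
    t,
    old_format=tv_tensors.BoundingBoxFormat.XYWH,
    new_format=tv_tensors.BoundingBoxFormat.XYXY,
)

# BoundingBoxes path: the format is read from the input; passing old_format raises.
boxes = tv_tensors.BoundingBoxes(t, format="XYWH", canvas_size=(32, 32))
out = F.convert_bounding_box_format(boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY)
print(out.format)  # BoundingBoxFormat.XYXY -- re-attached by tv_tensors.wrap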
torchvision/transforms/v2/functional/_misc.py
View file @ d5f4cc38

@@ -5,7 +5,7 @@ import PIL.Image
 import torch
 from torch.nn.functional import conv2d, pad as torch_pad
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms._functional_tensor import _max_value
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image

@@ -31,7 +31,7 @@ def normalize(
 @_register_kernel_internal(normalize, torch.Tensor)
-@_register_kernel_internal(normalize, datapoints.Image)
+@_register_kernel_internal(normalize, tv_tensors.Image)
 def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
     if not image.is_floating_point():
         raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")

@@ -65,7 +65,7 @@ def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
     return image.div_(std)

-@_register_kernel_internal(normalize, datapoints.Video)
+@_register_kernel_internal(normalize, tv_tensors.Video)
 def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
     return normalize_image(video, mean, std, inplace=inplace)

@@ -98,7 +98,7 @@ def _get_gaussian_kernel2d(
 @_register_kernel_internal(gaussian_blur, torch.Tensor)
-@_register_kernel_internal(gaussian_blur, datapoints.Image)
+@_register_kernel_internal(gaussian_blur, tv_tensors.Image)
 def gaussian_blur_image(image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:

@@ -172,7 +172,7 @@ def _gaussian_blur_image_pil(
     return to_pil_image(output, mode=image.mode)

-@_register_kernel_internal(gaussian_blur, datapoints.Video)
+@_register_kernel_internal(gaussian_blur, tv_tensors.Video)
 def gaussian_blur_video(video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:

@@ -206,7 +206,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
 @_register_kernel_internal(to_dtype, torch.Tensor)
-@_register_kernel_internal(to_dtype, datapoints.Image)
+@_register_kernel_internal(to_dtype, tv_tensors.Image)
 def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if image.dtype == dtype:

@@ -265,13 +265,13 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
     return to_dtype_image(image, dtype=dtype, scale=True)

-@_register_kernel_internal(to_dtype, datapoints.Video)
+@_register_kernel_internal(to_dtype, tv_tensors.Video)
 def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     return to_dtype_image(video, dtype, scale=scale)

-@_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
-@_register_kernel_internal(to_dtype, datapoints.Mask, datapoint_wrapper=False)
+@_register_kernel_internal(to_dtype, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
+@_register_kernel_internal(to_dtype, tv_tensors.Mask, tv_tensor_wrapper=False)
 def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
-    # We don't need to unwrap and rewrap here, since Datapoint.to() preserves the type
+    # We don't need to unwrap and rewrap here, since TVTensor.to() preserves the type
     return inpt.to(dtype)
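
A sketch of the `to_dtype` behaviour the last hunk encodes, assuming the renamed API: images and videos can be value-scaled, while boxes and masks only change dtype:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

img = tv_tensors.Image(torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8))
floats = F.to_dtype(img, dtype=torch.float32, scale=True)  # values rescaled to [0, 1]

boxes = tv_tensors.BoundingBoxes([[0, 0, 2, 2]], format="XYXY", canvas_size=(4, 4))
# For boxes and masks only the dtype changes: `scale` is ignored, and the subclass
# survives because TVTensor.to() preserves the type, as the comment above notes.
print(type(F.to_dtype(boxes, dtype=torch.float32, scale=True)).__name__)  # BoundingBoxes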
torchvision/transforms/v2/functional/_temporal.py
View file @ d5f4cc38

 import torch
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.utils import _log_api_usage_once

@@ -19,7 +19,7 @@ def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor:
 @_register_kernel_internal(uniform_temporal_subsample, torch.Tensor)
-@_register_kernel_internal(uniform_temporal_subsample, datapoints.Video)
+@_register_kernel_internal(uniform_temporal_subsample, tv_tensors.Video)
 def uniform_temporal_subsample_video(video: torch.Tensor, num_samples: int) -> torch.Tensor:
     # Reference: https://github.com/facebookresearch/pytorchvideo/blob/a0a131e/pytorchvideo/transforms/functional.py#L19
     t_max = video.shape[-4] - 1
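
A minimal sketch of the subsampling kernel, which picks `num_samples` indices evenly spaced over the temporal (`-4`) dimension, under the same renamed-API assumption:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

video = tv_tensors.Video(torch.rand(8, 3, 2, 2))         # (T, C, H, W)
clip = F.uniform_temporal_subsample(video, num_samples=4)
print(clip.shape)  # torch.Size([4, 3, 2, 2]) -- frames 0, 2, 4 and 7 are kept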
torchvision/transforms/v2/functional/_type_conversion.py
View file @ d5f4cc38

@@ -3,12 +3,12 @@ from typing import Union
 import numpy as np
 import PIL.Image
 import torch
-from torchvision import datapoints
+from torchvision import tv_tensors
 from torchvision.transforms import functional as _F

 @torch.jit.unused
-def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
+def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image:
     """[BETA] See :class:`~torchvision.transforms.v2.ToImage` for details."""
     if isinstance(inpt, np.ndarray):
         output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()

@@ -18,7 +18,7 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
         output = inpt
     else:
         raise TypeError(f"Input can either be a numpy array or a PIL image, but got {type(inpt)} instead.")
-    return datapoints.Image(output)
+    return tv_tensors.Image(output)

 to_pil_image = _F.to_pil_image
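
A sketch of `to_image`, which converts HWC numpy arrays (or PIL images) to CHW and wraps the result, assuming the renamed namespace:

import numpy as np
from torchvision.transforms.v2 import functional as F

arr = np.zeros((4, 4, 3), dtype=np.uint8)     # HWC layout, as most image loaders return
img = F.to_image(arr)                         # permuted to CHW and wrapped
print(type(img).__name__, tuple(img.shape))   # Image (3, 4, 4)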
torchvision/transforms/v2/functional/_utils.py
View file @ d5f4cc38

@@ -2,21 +2,21 @@ import functools
 from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

 import torch
-from torchvision import datapoints
+from torchvision import tv_tensors

 _FillType = Union[int, float, Sequence[int], Sequence[float], None]
 _FillTypeJIT = Optional[List[float]]

 def is_pure_tensor(inpt: Any) -> bool:
-    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)
+    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, tv_tensors.TVTensor)

 # {functional: {input_type: type_specific_kernel}}
 _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}

-def _kernel_datapoint_wrapper(kernel):
+def _kernel_tv_tensor_wrapper(kernel):
     @functools.wraps(kernel)
     def wrapper(inpt, *args, **kwargs):
         # If you're wondering whether we could / should get rid of this wrapper,

@@ -25,24 +25,24 @@ def _kernel_datapoint_wrapper(kernel):
         # regardless of whether we override __torch_function__ in our base class
         # or not.
         # Also, even if we didn't call `as_subclass` here, we would still need
-        # this wrapper to call wrap(), because the Datapoint type would be
+        # this wrapper to call wrap(), because the TVTensor type would be
         # lost after the first operation due to our own __torch_function__
         # logic.
         output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
-        return datapoints.wrap(output, like=inpt)
+        return tv_tensors.wrap(output, like=inpt)

     return wrapper

-def _register_kernel_internal(functional, input_type, *, datapoint_wrapper=True):
+def _register_kernel_internal(functional, input_type, *, tv_tensor_wrapper=True):
     registry = _KERNEL_REGISTRY.setdefault(functional, {})
     if input_type in registry:
         raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")

     def decorator(kernel):
         registry[input_type] = (
-            _kernel_datapoint_wrapper(kernel)
-            if issubclass(input_type, datapoints.Datapoint) and datapoint_wrapper
+            _kernel_tv_tensor_wrapper(kernel)
+            if issubclass(input_type, tv_tensors.TVTensor) and tv_tensor_wrapper
             else kernel
         )
         return kernel

@@ -62,14 +62,14 @@ def _name_to_functional(name):
 _BUILTIN_DATAPOINT_TYPES = {
-    obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint)
+    obj for obj in tv_tensors.__dict__.values() if isinstance(obj, type) and issubclass(obj, tv_tensors.TVTensor)
 }

-def register_kernel(functional, datapoint_cls):
-    """[BETA] Decorate a kernel to register it for a functional and a (custom) datapoint type.
+def register_kernel(functional, tv_tensor_cls):
+    """[BETA] Decorate a kernel to register it for a functional and a (custom) tv_tensor type.

-    See :ref:`sphx_glr_auto_examples_transforms_plot_custom_datapoints.py` for usage
+    See :ref:`sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py` for usage
     details.
     """
     if isinstance(functional, str):

@@ -83,16 +83,16 @@ def register_kernel(functional, datapoint_cls):
             f"but got {functional}."
         )

-    if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
+    if not (isinstance(tv_tensor_cls, type) and issubclass(tv_tensor_cls, tv_tensors.TVTensor)):
         raise ValueError(
-            f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
-            f"but got {datapoint_cls}."
+            f"Kernels can only be registered for subclasses of torchvision.tv_tensors.TVTensor, "
+            f"but got {tv_tensor_cls}."
         )

-    if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
-        raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")
+    if tv_tensor_cls in _BUILTIN_DATAPOINT_TYPES:
+        raise ValueError(f"Kernels cannot be registered for the builtin tv_tensor classes, but got {tv_tensor_cls}")

-    return _register_kernel_internal(functional, datapoint_cls, datapoint_wrapper=False)
+    return _register_kernel_internal(functional, tv_tensor_cls, tv_tensor_wrapper=False)

@@ -103,10 +103,10 @@ def _get_kernel(functional, input_type, *, allow_passthrough=False):
     for cls in input_type.__mro__:
         if cls in registry:
             return registry[cls]
-        elif cls is datapoints.Datapoint:
-            # We don't want user-defined datapoints to dispatch to the pure Tensor kernels, so we explicit stop the
-            # MRO traversal before hitting torch.Tensor. We can even stop at datapoints.Datapoint, since we don't
-            # allow kernels to be registered for datapoints.Datapoint anyway.
+        elif cls is tv_tensors.TVTensor:
+            # We don't want user-defined tv_tensors to dispatch to the pure Tensor kernels, so we explicit stop the
+            # MRO traversal before hitting torch.Tensor. We can even stop at tv_tensors.TVTensor, since we don't
+            # allow kernels to be registered for tv_tensors.TVTensor anyway.
             break

     if allow_passthrough:

@@ -130,12 +130,12 @@ def _register_five_ten_crop_kernel_internal(functional, input_type):
         def wrapper(inpt, *args, **kwargs):
             output = kernel(inpt, *args, **kwargs)
             container_type = type(output)
-            return container_type(datapoints.wrap(o, like=inpt) for o in output)
+            return container_type(tv_tensors.wrap(o, like=inpt) for o in output)

         return wrapper

     def decorator(kernel):
-        registry[input_type] = wrap(kernel) if issubclass(input_type, datapoints.Datapoint) else kernel
+        registry[input_type] = wrap(kernel) if issubclass(input_type, tv_tensors.TVTensor) else kernel
         return kernel

     return decorator
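
`register_kernel` is the public entry to this registry for custom classes, and since it registers with `tv_tensor_wrapper=False`, the kernel has to re-wrap its own output. A minimal sketch, assuming `register_kernel` is re-exported from `torchvision.transforms.v2.functional`:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

class MyTVTensor(tv_tensors.TVTensor):
    pass

@F.register_kernel(F.horizontal_flip, MyTVTensor)
def hflip_my_tv_tensor(inpt, *args, **kwargs):
    out = inpt.as_subclass(torch.Tensor).flip(-1)
    # Custom kernels are registered without the automatic wrapper,
    # so they must re-wrap their own output:
    return tv_tensors.wrap(out, like=inpt)

x = torch.arange(4.0).reshape(1, 1, 4).as_subclass(MyTVTensor)
print(F.horizontal_flip(x))  # dispatches to hflip_my_tv_tensor via the MRO lookup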
torchvision/
datapoint
s/__init__.py
→
torchvision/
tv_tensor
s/__init__.py
View file @
d5f4cc38
 import torch
 
 from ._bounding_box import BoundingBoxes, BoundingBoxFormat
-from ._datapoint import Datapoint
 from ._image import Image
 from ._mask import Mask
 from ._torch_function_helpers import set_return_type
+from ._tv_tensor import TVTensor
 from ._video import Video
 
 
 def wrap(wrappee, *, like, **kwargs):
-    """[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.datapoints.Datapoint` subclass as ``like``.
+    """[BETA] Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.
 
-    If ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`, the ``format`` and ``canvas_size`` of
+    If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of
     ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.
 
     Args:
         wrappee (Tensor): The tensor to convert.
-        like (:class:`~torchvision.datapoints.Datapoint`): The reference.
+        like (:class:`~torchvision.tv_tensors.TVTensor`): The reference.
             ``wrappee`` will be converted into the same subclass as ``like``.
-        kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.datapoint.BoundingBoxes`.
+        kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.tv_tensor.BoundingBoxes`.
             Ignored otherwise.
     """
     if isinstance(like, BoundingBoxes):
...
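A hedged usage example for the `wrap` helper documented above (the box values are illustrative):

    import torch
    from torchvision import tv_tensors

    boxes = tv_tensors.BoundingBoxes(
        [[10, 10, 20, 20]], format="XYXY", canvas_size=(64, 64)
    )
    shifted = boxes + 5                             # plain ops usually return a torch.Tensor
    shifted = tv_tensors.wrap(shifted, like=boxes)  # re-wrap, inheriting format and canvas_size
    assert isinstance(shifted, tv_tensors.BoundingBoxes)
    assert shifted.canvas_size == (64, 64)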
torchvision/datapoints/_bounding_box.py → torchvision/tv_tensors/_bounding_box.py
View file @ d5f4cc38
...
@@ -6,7 +6,7 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union
 import torch
 from torch.utils._pytree import tree_flatten
 
-from ._datapoint import Datapoint
+from ._tv_tensor import TVTensor
 
 
 class BoundingBoxFormat(Enum):
...
@@ -24,13 +24,13 @@ class BoundingBoxFormat(Enum):
     CXCYWH = "CXCYWH"
 
 
-class BoundingBoxes(Datapoint):
+class BoundingBoxes(TVTensor):
     """[BETA] :class:`torch.Tensor` subclass for bounding boxes.
 
     .. note::
-        There should be only one :class:`~torchvision.datapoints.BoundingBoxes`
+        There should be only one :class:`~torchvision.tv_tensors.BoundingBoxes`
         instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
-        although one :class:`~torchvision.datapoints.BoundingBoxes` object can
+        although one :class:`~torchvision.tv_tensors.BoundingBoxes` object can
         contain multiple bounding boxes.
 
     Args:
...
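For reference, constructing the renamed class directly, following the note above (sizes are illustrative):

    import torch
    from torchvision import tv_tensors

    # One BoundingBoxes instance may hold many boxes, but each sample should
    # carry a single instance, as the docstring's note recommends.
    bboxes = tv_tensors.BoundingBoxes(
        torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]]),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(32, 32),
    )
    sample = {"img": torch.rand(3, 32, 32), "bbox": bboxes}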
torchvision/datapoints/_dataset_wrapper.py → torchvision/tv_tensors/_dataset_wrapper.py
View file @ d5f4cc38
...
@@ -9,7 +9,7 @@ from collections import defaultdict
 import torch
 
-from torchvision import datapoints, datasets
+from torchvision import datasets, tv_tensors
 from torchvision.transforms.v2 import functional as F
 
 __all__ = ["wrap_dataset_for_transforms_v2"]
...
@@ -36,26 +36,26 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
     * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
       returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
-      ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``.
+      ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.tv_tensors``.
       The original keys are preserved. If ``target_keys`` is omitted, returns only the values for the
       ``"image_id"``, ``"boxes"``, and ``"labels"``.
     * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
-      the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
+      the target and wrap the data in the corresponding ``torchvision.tv_tensors``. The original keys are
       preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
     * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
-      coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint.
+      coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.
     * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as list of dicts, the wrapper returns a
       dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
-      in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is
+      in the corresponding ``torchvision.tv_tensors``. The original keys are preserved. If ``target_keys`` is
       omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
     * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
-      :class:`~torchvision.datapoints.Mask` datapoint.
+      :class:`~torchvision.tv_tensors.Mask` tv_tensor.
     * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
-      :class:`~torchvision.datapoints.Mask` datapoint. The target for ``target_type="instance"`` is *replaced* by
-      a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and
+      :class:`~torchvision.tv_tensors.Mask` tv_tensor. The target for ``target_type="instance"`` is *replaced* by
+      a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.tv_tensors.Mask` tv_tensor) and
       ``"labels"``.
     * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
-      coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint.
+      coordinate format and wrapped into a :class:`~torchvision.tv_tensors.BoundingBoxes` tv_tensor.
 
     Image classification datasets
...
@@ -66,13 +66,13 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
         Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
         :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
-        segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).
+        segmentation mask into a :class:`~torchvision.tv_tensors.Mask` (second item).
 
     Video classification datasets
 
         Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
         :class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
-        :class:`~torchvision.datapoints.Video` while leaving the other items as is.
+        :class:`~torchvision.tv_tensors.Video` while leaving the other items as is.
 
     .. note::
...
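In user code, the behavior described in this docstring is a single call. A hedged sketch (the dataset paths are placeholders):

    from torchvision import datasets

    dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
    dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=("boxes", "labels"))

    img, target = dataset[0]
    # target["boxes"] is a tv_tensors.BoundingBoxes in XYXY format and
    # target["labels"] a plain label tensor, per the docstring above.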
@@ -98,12 +98,12 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
         )
 
     # Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
-    # "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the
+    # "WrappedImageNet" at runtime that doubly inherits from VisionDatasetTVTensorWrapper (see below) as well as the
    # original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
     # while we can still inject everything that we need.
-    wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetDatapointWrapper, type(dataset)), {})
-    # Since VisionDatasetDatapointWrapper comes before ImageNet in the MRO, calling the class hits
-    # VisionDatasetDatapointWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
+    wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetTVTensorWrapper, type(dataset)), {})
+    # Since VisionDatasetTVTensorWrapper comes before ImageNet in the MRO, calling the class hits
+    # VisionDatasetTVTensorWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
     # ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
     # have the existing instance as attribute on the new object.
     return wrapped_dataset_cls(dataset, target_keys)
...
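The comments above describe a generic double-inheritance trick with `type()`. A self-contained sketch of the pattern (class names are illustrative, not torchvision's):

    class Wrapper:
        def __init__(self, wrappee):
            # Deliberately no super().__init__(): the wrapped instance already
            # exists, so we only keep a reference to it.
            self._wrappee = wrappee

        def __getitem__(self, idx):
            return self._wrappee.data[idx]

    class Base:
        def __init__(self, data):
            self.data = data

    base = Base([1, 2, 3])
    # Wrapper comes first in the MRO, so Wrapper.__init__ runs and Base.__init__
    # is never hit -- yet isinstance(wrapped, Base) still holds.
    WrappedBase = type(f"Wrapped{type(base).__name__}", (Wrapper, type(base)), {})
    wrapped = WrappedBase(base)
    assert isinstance(wrapped, Base) and wrapped[0] == 1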
@@ -125,7 +125,7 @@ class WrapperFactories(dict):
 WRAPPER_FACTORIES = WrapperFactories()
 
 
-class VisionDatasetDatapointWrapper:
+class VisionDatasetTVTensorWrapper:
     def __init__(self, dataset, target_keys):
         dataset_cls = type(dataset)
...
@@ -134,7 +134,7 @@ class VisionDatasetDatapointWrapper:
                 f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
                 f"but got a '{dataset_cls.__name__}' instead.\n"
                 f"For an example of how to perform the wrapping for custom datasets, see\n\n"
-                "https://pytorch.org/vision/main/auto_examples/plot_datapoints.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
+                "https://pytorch.org/vision/main/auto_examples/plot_tv_tensors.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
             )
 
         for cls in dataset_cls.mro():
...
@@ -221,7 +221,7 @@ def identity_wrapper_factory(dataset, target_keys):
 def pil_image_to_mask(pil_image):
-    return datapoints.Mask(pil_image)
+    return tv_tensors.Mask(pil_image)
 
 
 def parse_target_keys(target_keys, *, available, default):
...
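What `pil_image_to_mask` amounts to in isolation, as a hedged sketch (the image is synthetic):

    import PIL.Image
    from torchvision import tv_tensors

    pil_mask = PIL.Image.new("L", (4, 4), color=1)  # synthetic segmentation mask
    mask = tv_tensors.Mask(pil_mask)                # same conversion as the helper above
    assert isinstance(mask, tv_tensors.Mask)        # a Mask is also a torch.Tensor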
@@ -302,7 +302,7 @@ def video_classification_wrapper_factory(dataset, target_keys):
     def wrapper(idx, sample):
         video, audio, label = sample
 
-        video = datapoints.Video(video)
+        video = tv_tensors.Video(video)
 
         return video, audio, label
...
@@ -373,16 +373,16 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
         if "boxes" in target_keys:
             target["boxes"] = F.convert_bounding_box_format(
-                datapoints.BoundingBoxes(
+                tv_tensors.BoundingBoxes(
                     batched_target["bbox"],
-                    format=datapoints.BoundingBoxFormat.XYWH,
+                    format=tv_tensors.BoundingBoxFormat.XYWH,
                     canvas_size=canvas_size,
                 ),
-                new_format=datapoints.BoundingBoxFormat.XYXY,
+                new_format=tv_tensors.BoundingBoxFormat.XYXY,
             )
 
         if "masks" in target_keys:
-            target["masks"] = datapoints.Mask(
+            target["masks"] = tv_tensors.Mask(
                 torch.stack(
                     [
                         segmentation_to_mask(segmentation, canvas_size=canvas_size)
...
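The format conversion used throughout these factories can be called directly. A hedged sketch with illustrative numbers:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    xywh = tv_tensors.BoundingBoxes(
        torch.tensor([[10, 10, 20, 30]]),  # x, y, width, height
        format=tv_tensors.BoundingBoxFormat.XYWH,
        canvas_size=(100, 100),
    )
    xyxy = F.convert_bounding_box_format(xywh, new_format=tv_tensors.BoundingBoxFormat.XYXY)
    # tensor([[10, 10, 30, 40]]): (x1, y1) is kept, (x2, y2) = (x1 + w, y1 + h)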
@@ -454,12 +454,12 @@ def voc_detection_wrapper_factory(dataset, target_keys):
         target = {}
 
         if "boxes" in target_keys:
-            target["boxes"] = datapoints.BoundingBoxes(
+            target["boxes"] = tv_tensors.BoundingBoxes(
                 [
                     [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
                     for bndbox in batched_instances["bndbox"]
                 ],
-                format=datapoints.BoundingBoxFormat.XYXY,
+                format=tv_tensors.BoundingBoxFormat.XYXY,
                 canvas_size=(image.height, image.width),
             )
...
@@ -494,12 +494,12 @@ def celeba_wrapper_factory(dataset, target_keys):
             target_types=dataset.target_type,
             type_wrappers={
                 "bbox": lambda item: F.convert_bounding_box_format(
-                    datapoints.BoundingBoxes(
+                    tv_tensors.BoundingBoxes(
                         item,
-                        format=datapoints.BoundingBoxFormat.XYWH,
+                        format=tv_tensors.BoundingBoxFormat.XYWH,
                         canvas_size=(image.height, image.width),
                     ),
-                    new_format=datapoints.BoundingBoxFormat.XYXY,
+                    new_format=tv_tensors.BoundingBoxFormat.XYXY,
                 ),
             },
         )
...
@@ -544,9 +544,9 @@ def kitti_wrapper_factory(dataset, target_keys):
         target = {}
 
         if "boxes" in target_keys:
-            target["boxes"] = datapoints.BoundingBoxes(
+            target["boxes"] = tv_tensors.BoundingBoxes(
                 batched_target["bbox"],
-                format=datapoints.BoundingBoxFormat.XYXY,
+                format=tv_tensors.BoundingBoxFormat.XYXY,
                 canvas_size=(image.height, image.width),
             )
...
@@ -596,7 +596,7 @@ def cityscapes_wrapper_factory(dataset, target_keys):
             if label >= 1_000:
                 label //= 1_000
             labels.append(label)
-        return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))
+        return dict(masks=tv_tensors.Mask(torch.stack(masks)), labels=torch.stack(labels))
 
     def wrapper(idx, sample):
         image, target = sample
...
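For context on the `label //= 1_000` above: Cityscapes encodes instance ids as `class_id * 1000 + instance_index` for instance-level classes, so integer division recovers the semantic class. A minimal illustration (the concrete id is made up):

    instance_id = 26_004  # hypothetical pixel value: class 26, instance index 4
    label = instance_id
    if label >= 1_000:
        label //= 1_000
    assert label == 26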
@@ -641,10 +641,10 @@ def widerface_wrapper(dataset, target_keys):
         if "bbox" in target_keys:
             target["bbox"] = F.convert_bounding_box_format(
-                datapoints.BoundingBoxes(
-                    target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
+                tv_tensors.BoundingBoxes(
+                    target["bbox"], format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=(image.height, image.width)
                 ),
-                new_format=datapoints.BoundingBoxFormat.XYXY,
+                new_format=tv_tensors.BoundingBoxFormat.XYXY,
             )
 
         return image, target
...