Commit 332bff93 (unverified)
Authored Jul 31, 2023 by Nicolas Hug; committed via GitHub on Jul 31, 2023

Renaming: `BoundingBox` -> `BoundingBoxes` (#7778)

Parent: d4e5aa21
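This commit is a mechanical rename: the `BoundingBox` datapoint class, the `*BoundingBox` transforms, and every kernel and helper carrying `bounding_box` in its name gain a plural `es`, reflecting that the container has always held N boxes as an `(N, 4)` tensor. A minimal before/after sketch of user-facing code (hedged: it assumes the `torchvision.datapoints` API as of this commit; the concrete box values are illustrative):

    from torchvision import datapoints

    # Before: datapoints.BoundingBox(...)
    # After:
    boxes = datapoints.BoundingBoxes(
        [[17, 16, 344, 495]],  # one XYXY box; the container is inherently plural
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(512, 512),
    )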
Showing 20 changed files with 405 additions and 405 deletions (+405, -405).
docs/source/datapoints.rst                  +1   -1
docs/source/transforms.rst                  +2   -2
gallery/plot_datapoints.py                  +6   -6
gallery/plot_transforms_v2.py               +1   -1
gallery/plot_transforms_v2_e2e.py           +2   -2
references/detection/presets.py             +1   -1
test/common_utils.py                        +9   -9
test/test_datapoints.py                     +2   -2
test/test_prototype_transforms.py           +14  -14
test/test_transforms_v2.py                  +39  -39
test/test_transforms_v2_consistency.py      +1   -1
test/test_transforms_v2_functional.py       +43  -43
test/test_transforms_v2_refactored.py       +113 -113
test/test_transforms_v2_utils.py            +18  -18
test/transforms_v2_dispatcher_infos.py      +10  -10
test/transforms_v2_kernel_infos.py          +85  -85
torchvision/datapoints/__init__.py          +1   -1
torchvision/datapoints/_bounding_box.py     +45  -45
torchvision/datapoints/_datapoint.py        +2   -2
torchvision/datapoints/_dataset_wrapper.py  +10  -10
docs/source/datapoints.rst

@@ -15,5 +15,5 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
     Image
     Video
     BoundingBoxFormat
-    BoundingBox
+    BoundingBoxes
     Mask
docs/source/transforms.rst

@@ -206,8 +206,8 @@ Miscellaneous
     v2.RandomErasing
     Lambda
     v2.Lambda
-    v2.SanitizeBoundingBox
-    v2.ClampBoundingBox
+    v2.SanitizeBoundingBoxes
+    v2.ClampBoundingBoxes
     v2.UniformTemporalSubsample

 .. _conversion_transforms:
gallery/plot_datapoints.py

@@ -47,7 +47,7 @@ assert image.data_ptr() == tensor.data_ptr()
 #
 # * :class:`~torchvision.datapoints.Image`
 # * :class:`~torchvision.datapoints.Video`
-# * :class:`~torchvision.datapoints.BoundingBox`
+# * :class:`~torchvision.datapoints.BoundingBoxes`
 # * :class:`~torchvision.datapoints.Mask`
 #
 # How do I construct a datapoint?
@@ -76,10 +76,10 @@ print(image.shape, image.dtype)
 ########################################################################################################################
 # In general, the datapoints can also store additional metadata that complements the underlying tensor. For example,
-# :class:`~torchvision.datapoints.BoundingBox` stores the coordinate format as well as the spatial size of the
+# :class:`~torchvision.datapoints.BoundingBoxes` stores the coordinate format as well as the spatial size of the
 # corresponding image alongside the actual values:

-bounding_box = datapoints.BoundingBox(
+bounding_box = datapoints.BoundingBoxes(
     [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
 )
 print(bounding_box)
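For reference, the metadata mentioned in the hunk above is exposed as plain attributes on the datapoint; a minimal sketch (the `spatial_size` value stands in for `image.shape[-2:]` and is assumed):

    from torchvision import datapoints

    bounding_box = datapoints.BoundingBoxes(
        [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(512, 512)
    )
    print(bounding_box.format)        # BoundingBoxFormat.XYXY
    print(bounding_box.spatial_size)  # (512, 512)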
@@ -105,7 +105,7 @@ class PennFudanDataset(torch.utils.data.Dataset):
     def __getitem__(self, item):
         ...
-        target["boxes"] = datapoints.BoundingBox(
+        target["boxes"] = datapoints.BoundingBoxes(
             boxes,
             format=datapoints.BoundingBoxFormat.XYXY,
             spatial_size=F.get_spatial_size(img),
@@ -126,7 +126,7 @@ class PennFudanDataset(torch.utils.data.Dataset):
 class WrapPennFudanDataset:
     def __call__(self, img, target):
-        target["boxes"] = datapoints.BoundingBox(
+        target["boxes"] = datapoints.BoundingBoxes(
             target["boxes"],
             format=datapoints.BoundingBoxFormat.XYXY,
             spatial_size=F.get_spatial_size(img),
@@ -147,7 +147,7 @@ def get_transform(train):
 ########################################################################################################################
 # .. note::
 #
-#    If both :class:`~torchvision.datapoints.BoundingBox`'es and :class:`~torchvision.datapoints.Mask`'s are included in
+#    If both :class:`~torchvision.datapoints.BoundingBoxes`'es and :class:`~torchvision.datapoints.Mask`'s are included in
 #    the sample, ``torchvision.transforms.v2`` will transform them both. Meaning, if you don't need both, dropping or
 #    at least not wrapping the obsolete parts, can lead to a significant performance boost.
 #
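A minimal sketch of the advice in that note, assuming a box-only detection setup (the wrapper below is illustrative and not part of this commit):

    def drop_masks(img, target):
        # If the model only consumes boxes, not wrapping (or simply dropping)
        # the masks means transforms.v2 never has to transform them.
        target.pop("masks", None)
        return img, target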
gallery/plot_transforms_v2.py

@@ -29,7 +29,7 @@ def load_data():
     masks = datapoints.Mask(merged_masks == labels.view(-1, 1, 1))
-    bounding_boxes = datapoints.BoundingBox(
+    bounding_boxes = datapoints.BoundingBoxes(
         masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
     )
gallery/plot_transforms_v2_e2e.py

@@ -106,13 +106,13 @@ transform = transforms.Compose(
         transforms.RandomHorizontalFlip(),
         transforms.ToImageTensor(),
         transforms.ConvertImageDtype(torch.float32),
-        transforms.SanitizeBoundingBox(),
+        transforms.SanitizeBoundingBoxes(),
     ]
 )

 ########################################################################################################################
 # .. note::
-#    Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBox` transform is a no-op in this example, but it
+#    Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, but it
 #    should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
 #    the corresponding labels and optionally masks. It is particularly critical to add it if
 #    :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
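To make that note concrete, a small sketch of what the transform removes (assumed values; it relies on the post-rename v2 API and on the default labels_getter finding the "labels" key):

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2 as transforms

    boxes = datapoints.BoundingBoxes(
        [[0, 0, 10, 10], [5, 5, 5, 5]],  # second box is degenerate: no width or height
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(32, 32),
    )
    sample = {"boxes": boxes, "labels": torch.tensor([1, 2])}
    out = transforms.SanitizeBoundingBoxes()(sample)
    # The degenerate box is dropped together with its label:
    assert out["boxes"].shape == (1, 4) and out["labels"].tolist() == [1]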
references/detection/presets.py

@@ -78,7 +78,7 @@ class DetectionPresetTrain:
         if use_v2:
             transforms += [
                 T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
-                T.SanitizeBoundingBox(),
+                T.SanitizeBoundingBoxes(),
             ]

         self.transforms = T.Compose(transforms)
test/common_utils.py

@@ -620,7 +620,7 @@ def make_image_loaders_for_interpolation(
 @dataclasses.dataclass
-class BoundingBoxLoader(TensorLoader):
+class BoundingBoxesLoader(TensorLoader):
     format: datapoints.BoundingBoxFormat
     spatial_size: Tuple[int, int]

@@ -639,7 +639,7 @@ def make_bounding_box(
               - (box[3] - box[1], box[2] - box[0]) for XYXY
               - (H, W) for XYWH and CXCYWH
         spatial_size: Size of the reference object, e.g. an image. Corresponds to the .spatial_size attribute on
-            returned datapoints.BoundingBox
+            returned datapoints.BoundingBoxes

             To generate a valid joint sample, you need to set spatial_size here to the same value as size on the other maker
             functions, e.g.

@@ -647,8 +647,8 @@ def make_bounding_box(
             .. code::

                 image = make_image=(size=size)
-                bounding_box = make_bounding_box(spatial_size=size)
-                assert F.get_spatial_size(bounding_box) == F.get_spatial_size(image)
+                bounding_boxes = make_bounding_box(spatial_size=size)
+                assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)

             For convenience, if both size and spatial_size are omitted, spatial_size defaults to the same value as size for all
             other maker functions, e.g.

@@ -656,8 +656,8 @@ def make_bounding_box(
             .. code::

                 image = make_image=()
-                bounding_box = make_bounding_box()
-                assert F.get_spatial_size(bounding_box) == F.get_spatial_size(image)
+                bounding_boxes = make_bounding_box()
+                assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
     """

     def sample_position(values, max_value):

@@ -679,7 +679,7 @@ def make_bounding_box(
     dtype = dtype or torch.float32

     if any(dim == 0 for dim in batch_dims):
-        return datapoints.BoundingBox(
+        return datapoints.BoundingBoxes(
             torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
         )

@@ -705,7 +705,7 @@ def make_bounding_box(
     else:
         raise ValueError(f"Format {format} is not supported")

-    return datapoints.BoundingBox(
+    return datapoints.BoundingBoxes(
         torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
     )

@@ -725,7 +725,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORT
             format=format, spatial_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
         )

-    return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
+    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)


 def make_bounding_box_loaders(
test/test_datapoints.py

@@ -27,7 +27,7 @@ def test_mask_instance(data):
     "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
 )
 def test_bbox_instance(data, format):
-    bboxes = datapoints.BoundingBox(data, format=format, spatial_size=(32, 32))
+    bboxes = datapoints.BoundingBoxes(data, format=format, spatial_size=(32, 32))
     assert isinstance(bboxes, torch.Tensor)
     assert bboxes.ndim == 2 and bboxes.shape[1] == 4
     if isinstance(format, str):

@@ -164,7 +164,7 @@ def test_wrap_like():
     [
         datapoints.Image(torch.rand(3, 16, 16)),
         datapoints.Video(torch.rand(2, 3, 16, 16)),
-        datapoints.BoundingBox([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)),
+        datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)),
         datapoints.Mask(torch.randint(0, 256, (16, 16), dtype=torch.uint8)),
     ],
 )
test/test_prototype_transforms.py

@@ -20,7 +20,7 @@ from common_utils import (
 from prototype_common_utils import make_label, make_one_hot_labels
-from torchvision.datapoints import BoundingBox, BoundingBoxFormat, Image, Mask, Video
+from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
 from torchvision.transforms.v2._utils import _convert_fill_arg
 from torchvision.transforms.v2.functional import InterpolationMode, pil_to_tensor, to_image_pil

@@ -101,10 +101,10 @@ class TestSimpleCopyPaste:
             self.create_fake_image(mocker, Image),
             # labels, bboxes, masks
             mocker.MagicMock(spec=datapoints.Label),
-            mocker.MagicMock(spec=BoundingBox),
+            mocker.MagicMock(spec=BoundingBoxes),
             mocker.MagicMock(spec=Mask),
             # labels, bboxes, masks
-            mocker.MagicMock(spec=BoundingBox),
+            mocker.MagicMock(spec=BoundingBoxes),
             mocker.MagicMock(spec=Mask),
         ]

@@ -122,11 +122,11 @@ class TestSimpleCopyPaste:
             self.create_fake_image(mocker, image_type),
             # labels, bboxes, masks
             mocker.MagicMock(spec=label_type),
-            mocker.MagicMock(spec=BoundingBox),
+            mocker.MagicMock(spec=BoundingBoxes),
             mocker.MagicMock(spec=Mask),
             # labels, bboxes, masks
             mocker.MagicMock(spec=label_type),
-            mocker.MagicMock(spec=BoundingBox),
+            mocker.MagicMock(spec=BoundingBoxes),
             mocker.MagicMock(spec=Mask),
         ]

@@ -142,7 +142,7 @@ class TestSimpleCopyPaste:
         for target in targets:
             for key, type_ in [
-                ("boxes", BoundingBox),
+                ("boxes", BoundingBoxes),
                 ("masks", Mask),
                 ("labels", label_type),
             ]:

@@ -163,7 +163,7 @@ class TestSimpleCopyPaste:
         if label_type == datapoints.OneHotLabel:
             labels = torch.nn.functional.one_hot(labels, num_classes=5)
         target = {
-            "boxes": BoundingBox(
+            "boxes": BoundingBoxes(
                 torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32)
             ),
             "masks": Mask(masks),

@@ -178,7 +178,7 @@ class TestSimpleCopyPaste:
         if label_type == datapoints.OneHotLabel:
             paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
         paste_target = {
-            "boxes": BoundingBox(
+            "boxes": BoundingBoxes(
                 torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32)
             ),
             "masks": Mask(paste_masks),

@@ -332,7 +332,7 @@ class TestFixedSizeCrop:
         assert_equal(output["masks"], masks[is_valid])
         assert_equal(output["labels"], labels[is_valid])

-    def test__transform_bounding_box_clamping(self, mocker):
+    def test__transform_bounding_boxes_clamping(self, mocker):
         batch_size = 3
         spatial_size = (10, 10)

@@ -349,15 +349,15 @@ class TestFixedSizeCrop:
             ),
         )

-        bounding_box = make_bounding_box(
+        bounding_boxes = make_bounding_box(
             format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
         )
-        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
+        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes")

         transform = transforms.FixedSizeCrop((-1, -1))
         mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)

-        transform(bounding_box)
+        transform(bounding_boxes)

         mock.assert_called_once()

@@ -390,7 +390,7 @@ class TestPermuteDimensions:
     def test_call(self, dims, inverse_dims):
         sample = dict(
             image=make_image(),
-            bounding_box=make_bounding_box(format=BoundingBoxFormat.XYXY),
+            bounding_boxes=make_bounding_box(format=BoundingBoxFormat.XYXY),
             video=make_video(),
             str="str",
             int=0,

@@ -434,7 +434,7 @@ class TestTransposeDimensions:
     def test_call(self, dims):
         sample = dict(
             image=make_image(),
-            bounding_box=make_bounding_box(format=BoundingBoxFormat.XYXY),
+            bounding_boxes=make_bounding_box(format=BoundingBoxFormat.XYXY),
             video=make_video(),
             str="str",
             int=0,
test/test_transforms_v2.py

@@ -46,8 +46,8 @@ def make_pil_images(*args, **kwargs):
 def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
-    for bounding_box in make_bounding_boxes(*args, **kwargs):
-        yield bounding_box.data
+    for bounding_boxes in make_bounding_boxes(*args, **kwargs):
+        yield bounding_boxes.data


 def parametrize(transforms_with_inputs):

@@ -69,7 +69,7 @@ def auto_augment_adapter(transform, input, device):
     adapted_input = {}
     image_or_video_found = False
     for key, value in input.items():
-        if isinstance(value, (datapoints.BoundingBox, datapoints.Mask)):
+        if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)):
             # AA transforms don't support bounding boxes or masks
             continue
         elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):

@@ -143,7 +143,7 @@ class TestSmoke:
             (transforms.RandomZoomOut(p=1.0), None),
             (transforms.Resize([16, 16], antialias=True), None),
             (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
-            (transforms.ClampBoundingBox(), None),
+            (transforms.ClampBoundingBoxes(), None),
             (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
             (transforms.ConvertImageDtype(), None),
             (transforms.GaussianBlur(kernel_size=3), None),

@@ -180,16 +180,16 @@ class TestSmoke:
             image_datapoint=make_image(size=spatial_size),
             video_datapoint=make_video(size=spatial_size),
             image_pil=next(make_pil_images(sizes=[spatial_size], color_spaces=["RGB"])),
-            bounding_box_xyxy=make_bounding_box(
+            bounding_boxes_xyxy=make_bounding_box(
                 format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(3,)
             ),
-            bounding_box_xywh=make_bounding_box(
+            bounding_boxes_xywh=make_bounding_box(
                 format=datapoints.BoundingBoxFormat.XYWH, spatial_size=spatial_size, batch_dims=(4,)
             ),
-            bounding_box_cxcywh=make_bounding_box(
+            bounding_boxes_cxcywh=make_bounding_box(
                 format=datapoints.BoundingBoxFormat.CXCYWH, spatial_size=spatial_size, batch_dims=(5,)
             ),
-            bounding_box_degenerate_xyxy=datapoints.BoundingBox(
+            bounding_boxes_degenerate_xyxy=datapoints.BoundingBoxes(
                 [
                     [0, 0, 0, 0],  # no height or width
                     [0, 0, 0, 1],  # no height

@@ -201,7 +201,7 @@ class TestSmoke:
                 format=datapoints.BoundingBoxFormat.XYXY,
                 spatial_size=spatial_size,
             ),
-            bounding_box_degenerate_xywh=datapoints.BoundingBox(
+            bounding_boxes_degenerate_xywh=datapoints.BoundingBoxes(
                 [
                     [0, 0, 0, 0],  # no height or width
                     [0, 0, 0, 1],  # no height

@@ -213,7 +213,7 @@ class TestSmoke:
                 format=datapoints.BoundingBoxFormat.XYWH,
                 spatial_size=spatial_size,
             ),
-            bounding_box_degenerate_cxcywh=datapoints.BoundingBox(
+            bounding_boxes_degenerate_cxcywh=datapoints.BoundingBoxes(
                 [
                     [0, 0, 0, 0],  # no height or width
                     [0, 0, 0, 1],  # no height

@@ -261,7 +261,7 @@ class TestSmoke:
                 else:
                     assert output_item is input_item

-            if isinstance(input_item, datapoints.BoundingBox) and not isinstance(
+            if isinstance(input_item, datapoints.BoundingBoxes) and not isinstance(
                 transform, transforms.ConvertBoundingBoxFormat
             ):
                 assert output_item.format == input_item.format

@@ -271,10 +271,10 @@ class TestSmoke:
         # TODO: we should test that against all degenerate boxes above
         for format in list(datapoints.BoundingBoxFormat):
             sample = dict(
-                boxes=datapoints.BoundingBox([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)),
+                boxes=datapoints.BoundingBoxes([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)),
                 labels=torch.tensor([3]),
             )
-            assert transforms.SanitizeBoundingBox()(sample)["boxes"].shape == (0, 4)
+            assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)

     @parametrize(
         [

@@ -942,7 +942,7 @@ class TestRandomErasing:
 class TestTransform:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
     )
     def test_check_transformed_types(self, inpt_type, mocker):
         # This test ensures that we correctly handle which types to transform and which to bypass

@@ -960,7 +960,7 @@ class TestTransform:
 class TestToImageTensor:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
         fn = mocker.patch(

@@ -971,7 +971,7 @@ class TestToImageTensor:
         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToImageTensor()
         transform(inpt)
-        if inpt_type in (datapoints.BoundingBox, datapoints.Image, str, int):
+        if inpt_type in (datapoints.BoundingBoxes, datapoints.Image, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt)

@@ -980,7 +980,7 @@ class TestToImageTensor:
 class TestToImagePIL:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
         fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")

@@ -988,7 +988,7 @@ class TestToImagePIL:
         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToImagePIL()
         transform(inpt)
-        if inpt_type in (datapoints.BoundingBox, PIL.Image.Image, str, int):
+        if inpt_type in (datapoints.BoundingBoxes, PIL.Image.Image, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt, mode=transform.mode)

@@ -997,7 +997,7 @@ class TestToImagePIL:
 class TestToPILImage:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
         fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")

@@ -1005,7 +1005,7 @@ class TestToPILImage:
         inpt = mocker.MagicMock(spec=inpt_type)
         transform = transforms.ToPILImage()
         transform(inpt)
-        if inpt_type in (PIL.Image.Image, datapoints.BoundingBox, str, int):
+        if inpt_type in (PIL.Image.Image, datapoints.BoundingBoxes, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt, mode=transform.mode)

@@ -1014,7 +1014,7 @@ class TestToPILImage:
 class TestToTensor:
     @pytest.mark.parametrize(
         "inpt_type",
-        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
+        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
     )
     def test__transform(self, inpt_type, mocker):
         fn = mocker.patch("torchvision.transforms.functional.to_tensor")

@@ -1023,7 +1023,7 @@ class TestToTensor:
         with pytest.warns(UserWarning, match="deprecated and will be removed"):
             transform = transforms.ToTensor()
         transform(inpt)
-        if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBox, str, int):
+        if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBoxes, str, int):
             assert fn.call_count == 0
         else:
             fn.assert_called_once_with(inpt)

@@ -1065,7 +1065,7 @@ class TestRandomIoUCrop:
         image = mocker.MagicMock(spec=datapoints.Image)
         image.num_channels = 3
         image.spatial_size = (24, 32)
-        bboxes = datapoints.BoundingBox(
+        bboxes = datapoints.BoundingBoxes(
             torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
             format="XYXY",
             spatial_size=image.spatial_size,

@@ -1103,7 +1103,7 @@ class TestRandomIoUCrop:
     def test__transform_empty_params(self, mocker):
         transform = transforms.RandomIoUCrop(sampler_options=[2.0])
         image = datapoints.Image(torch.rand(1, 3, 4, 4))
-        bboxes = datapoints.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4))
+        bboxes = datapoints.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4))
         label = torch.tensor([1])
         sample = [image, bboxes, label]
         # Let's mock transform._get_params to control the output:

@@ -1147,7 +1147,7 @@ class TestRandomIoUCrop:
         # check number of bboxes vs number of labels:
         output_bboxes = output[1]
-        assert isinstance(output_bboxes, datapoints.BoundingBox)
+        assert isinstance(output_bboxes, datapoints.BoundingBoxes)
         assert (output_bboxes[~is_within_crop_area] == 0).all()

         output_masks = output[2]

@@ -1505,7 +1505,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
         transforms.ConvertImageDtype(torch.float),
     ]
     if sanitize:
-        t += [transforms.SanitizeBoundingBox()]
+        t += [transforms.SanitizeBoundingBoxes()]
     t = transforms.Compose(t)

     num_boxes = 5

@@ -1523,7 +1523,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
     boxes[:, 2:] += boxes[:, :2]
     boxes = boxes.clamp(min=0, max=min(H, W))
-    boxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=(H, W))
+    boxes = datapoints.BoundingBoxes(boxes, format="XYXY", spatial_size=(H, W))

     masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))

@@ -1546,7 +1546,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
         # ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
         # doesn't remove them strictly speaking, it just marks some boxes as
         # degenerate and those boxes will be later removed by
-        # SanitizeBoundingBox(), which we add to the pipelines if the sanitize
+        # SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize
        # param is True.
        # Note that the values below are probably specific to the random seed
        # set above (which is fine).

@@ -1594,7 +1594,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
     boxes = torch.tensor(boxes)
     labels = torch.arange(boxes.shape[0])

-    boxes = datapoints.BoundingBox(
+    boxes = datapoints.BoundingBoxes(
         boxes,
         format=datapoints.BoundingBoxFormat.XYXY,
         spatial_size=(H, W),

@@ -1616,7 +1616,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
         img = sample.pop("image")
         sample = (img, sample)

-    out = transforms.SanitizeBoundingBox(min_size=min_size, labels_getter=labels_getter)(sample)
+    out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)

     if sample_type is tuple:
         out_image = out[0]

@@ -1634,7 +1634,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
         assert out_image is input_img
         assert out_whatever is whatever

-    assert isinstance(out_boxes, datapoints.BoundingBox)
+    assert isinstance(out_boxes, datapoints.BoundingBoxes)
     assert isinstance(out_masks, datapoints.Mask)

     if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):

@@ -1648,31 +1648,31 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
 def test_sanitize_bounding_boxes_errors():

-    good_bbox = datapoints.BoundingBox(
+    good_bbox = datapoints.BoundingBoxes(
         [[0, 0, 10, 10]],
         format=datapoints.BoundingBoxFormat.XYXY,
         spatial_size=(20, 20),
     )

     with pytest.raises(ValueError, match="min_size must be >= 1"):
-        transforms.SanitizeBoundingBox(min_size=0)
+        transforms.SanitizeBoundingBoxes(min_size=0)
     with pytest.raises(ValueError, match="labels_getter should either be 'default'"):
-        transforms.SanitizeBoundingBox(labels_getter=12)
+        transforms.SanitizeBoundingBoxes(labels_getter=12)

     with pytest.raises(ValueError, match="Could not infer where the labels are"):
         bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
-        transforms.SanitizeBoundingBox()(bad_labels_key)
+        transforms.SanitizeBoundingBoxes()(bad_labels_key)

     with pytest.raises(ValueError, match="must be a tensor"):
         not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
-        transforms.SanitizeBoundingBox()(not_a_tensor)
+        transforms.SanitizeBoundingBoxes()(not_a_tensor)

     with pytest.raises(ValueError, match="Number of boxes"):
         different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
-        transforms.SanitizeBoundingBox()(different_sizes)
+        transforms.SanitizeBoundingBoxes()(different_sizes)

     with pytest.raises(ValueError, match="boxes must be of shape"):
-        bad_bbox = datapoints.BoundingBox(  # batch with 2 elements
+        bad_bbox = datapoints.BoundingBoxes(  # batch with 2 elements
             [
                 [[0, 0, 10, 10]],
                 [[0, 0, 10, 10]],

@@ -1681,7 +1681,7 @@ def test_sanitize_bounding_boxes_errors():
             spatial_size=(20, 20),
         )
         different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
-        transforms.SanitizeBoundingBox()(different_sizes)
+        transforms.SanitizeBoundingBoxes()(different_sizes)


 @pytest.mark.parametrize(
test/test_transforms_v2_consistency.py

@@ -1127,7 +1127,7 @@ class TestRefDetTransforms:
             v2_transforms.Compose(
                 [
                     v2_transforms.RandomIoUCrop(),
-                    v2_transforms.SanitizeBoundingBox(labels_getter=lambda sample: sample[1]["labels"]),
+                    v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
                 ]
             ),
             {"with_mask": False},
test/test_transforms_v2_functional.py

@@ -26,7 +26,7 @@ from torchvision import datapoints
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
-from torchvision.transforms.v2.functional._meta import clamp_bounding_box, convert_format_bounding_box
+from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
 from torchvision.transforms.v2.utils import is_simple_tensor
 from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
 from transforms_v2_kernel_infos import KERNEL_INFOS

@@ -176,7 +176,7 @@ class TestKernels:
         # Everything to the left is considered a batch dimension.
         data_dims = {
             datapoints.Image: 3,
-            datapoints.BoundingBox: 1,
+            datapoints.BoundingBoxes: 1,
             # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
             # it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both a grouped under one
             # type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
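To make the data_dims table above concrete, a quick sketch of the shapes it encodes (values assumed for illustration):

    import torch
    from torchvision import datapoints

    # Image: 3 data dims (C, H, W); anything to the left is a batch dim.
    image = datapoints.Image(torch.rand(2, 3, 16, 16))  # one batch dim of size 2

    # BoundingBoxes: 1 data dim (the 4 coordinates); leading dims are batch dims.
    boxes = datapoints.BoundingBoxes(
        torch.rand(5, 4),  # 5 boxes in a single batch dim
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=(16, 16),
    )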
@@ -515,15 +515,15 @@ class TestDispatchers:
         [
             info
             for info in DISPATCHER_INFOS
-            if datapoints.BoundingBox in info.kernels and info.dispatcher is not F.convert_format_bounding_box
+            if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_format_bounding_boxes
         ],
-        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBox),
+        args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
     )
-    def test_bounding_box_format_consistency(self, info, args_kwargs):
-        (bounding_box, *other_args), kwargs = args_kwargs.load()
-        format = bounding_box.format
+    def test_bounding_boxes_format_consistency(self, info, args_kwargs):
+        (bounding_boxes, *other_args), kwargs = args_kwargs.load()
+        format = bounding_boxes.format

-        output = info.dispatcher(bounding_box, *other_args, **kwargs)
+        output = info.dispatcher(bounding_boxes, *other_args, **kwargs)

         assert output.format == format

@@ -562,7 +562,7 @@ def test_normalize_image_tensor_stats(device, num_channels):
     assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))


-class TestClampBoundingBox:
+class TestClampBoundingBoxes:
     @pytest.mark.parametrize(
         "metadata",
         [

@@ -575,7 +575,7 @@ class TestClampBoundingBox:
         simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

         with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` has to be passed")):
-            F.clamp_bounding_box(simple_tensor, **metadata)
+            F.clamp_bounding_boxes(simple_tensor, **metadata)

     @pytest.mark.parametrize(
         "metadata",

@@ -589,10 +589,10 @@ class TestClampBoundingBox:
         datapoint = next(make_bounding_boxes())

         with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` must not be passed")):
-            F.clamp_bounding_box(datapoint, **metadata)
+            F.clamp_bounding_boxes(datapoint, **metadata)


-class TestConvertFormatBoundingBox:
+class TestConvertFormatBoundingBoxes:
     @pytest.mark.parametrize(
         ("inpt", "old_format"),
         [

@@ -602,19 +602,19 @@ class TestConvertFormatBoundingBox:
     )
     def test_missing_new_format(self, inpt, old_format):
         with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
-            F.convert_format_bounding_box(inpt, old_format)
+            F.convert_format_bounding_boxes(inpt, old_format)

     def test_simple_tensor_insufficient_metadata(self):
         simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

         with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
-            F.convert_format_bounding_box(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
+            F.convert_format_bounding_boxes(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)

     def test_datapoint_explicit_metadata(self):
         datapoint = next(make_bounding_boxes())

         with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
-            F.convert_format_bounding_box(
+            F.convert_format_bounding_boxes(
                 datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
             )

@@ -658,7 +658,7 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
         [-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
     ],
 )
-def test_correctness_crop_bounding_box(device, format, top, left, height, width, expected_bboxes):
+def test_correctness_crop_bounding_boxes(device, format, top, left, height, width, expected_bboxes):

     # Expected bboxes computed using Albumentations:
     # import numpy as np

@@ -681,13 +681,13 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
     ]
     in_boxes = torch.tensor(in_boxes, device=device)
     if format != datapoints.BoundingBoxFormat.XYXY:
-        in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
+        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

-    expected_bboxes = clamp_bounding_box(
-        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
+    expected_bboxes = clamp_bounding_boxes(
+        datapoints.BoundingBoxes(expected_bboxes, format="XYXY", spatial_size=spatial_size)
     ).tolist()

-    output_boxes, output_spatial_size = F.crop_bounding_box(
+    output_boxes, output_spatial_size = F.crop_bounding_boxes(
         in_boxes,
         format,
         top,

@@ -697,7 +697,7 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
     )

     if format != datapoints.BoundingBoxFormat.XYXY:
-        output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
+        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
     torch.testing.assert_close(output_spatial_size, spatial_size)

@@ -727,7 +727,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
         [-5, 5, 35, 45, (32, 34)],
     ],
 )
-def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size):
+def test_correctness_resized_crop_bounding_boxes(device, format, top, left, height, width, size):
     def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
         # bbox should be xyxy
         bbox[0] = (bbox[0] - left_) * size_[1] / width_

@@ -747,16 +747,16 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
         expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
     expected_bboxes = torch.tensor(expected_bboxes, device=device)

-    in_boxes = datapoints.BoundingBox(
+    in_boxes = datapoints.BoundingBoxes(
         in_boxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size, device=device
     )
     if format != datapoints.BoundingBoxFormat.XYXY:
-        in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
+        in_boxes = convert_format_bounding_boxes(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

-    output_boxes, output_spatial_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
+    output_boxes, output_spatial_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)

     if format != datapoints.BoundingBoxFormat.XYXY:
-        output_boxes = convert_format_bounding_box(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
+        output_boxes = convert_format_bounding_boxes(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)

     torch.testing.assert_close(output_boxes, expected_bboxes)
     torch.testing.assert_close(output_spatial_size, size)

@@ -776,7 +776,7 @@ def _parse_padding(padding):
 @pytest.mark.parametrize("device", cpu_and_cuda())
 @pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
-def test_correctness_pad_bounding_box(device, padding):
+def test_correctness_pad_bounding_boxes(device, padding):
     def _compute_expected_bbox(bbox, padding_):
         pad_left, pad_up, _, _ = _parse_padding(padding_)

@@ -785,13 +785,13 @@ def test_correctness_pad_bounding_box(device, padding):
         bbox = (
             bbox.clone()
             if format == datapoints.BoundingBoxFormat.XYXY
-            else convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+            else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         )

         bbox[0::2] += pad_left
         bbox[1::2] += pad_up

-        bbox = convert_format_bounding_box(bbox, new_format=format)
+        bbox = convert_format_bounding_boxes(bbox, new_format=format)
         if bbox.dtype != dtype:
             # Temporary cast to original dtype
             # e.g. float32 -> int

@@ -808,7 +808,7 @@ def test_correctness_pad_bounding_box(device, padding):
         bboxes_format = bboxes.format
         bboxes_spatial_size = bboxes.spatial_size

-        output_boxes, output_spatial_size = F.pad_bounding_box(
+        output_boxes, output_spatial_size = F.pad_bounding_boxes(
             bboxes, format=bboxes_format, spatial_size=bboxes_spatial_size, padding=padding
         )

@@ -819,7 +819,7 @@ def test_correctness_pad_bounding_box(device, padding):
         expected_bboxes = []
         for bbox in bboxes:
-            bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
+            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
             expected_bboxes.append(_compute_expected_bbox(bbox, padding))

         if len(expected_bboxes) > 1:

@@ -849,7 +849,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
         [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
     ],
 )
-def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
+def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
     def _compute_expected_bbox(bbox, pcoeffs_):
         m1 = np.array(
             [

@@ -864,7 +864,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
             ]
         )

-        bbox_xyxy = convert_format_bounding_box(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+        bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],

@@ -884,14 +884,14 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
                 np.max(transformed_points[:, 1]),
             ]
         )
-        out_bbox = datapoints.BoundingBox(
+        out_bbox = datapoints.BoundingBoxes(
             out_bbox,
             format=datapoints.BoundingBoxFormat.XYXY,
             spatial_size=bbox.spatial_size,
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format))
+        return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format))

     spatial_size = (32, 38)

@@ -901,7 +901,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
     for bboxes in make_bounding_boxes(spatial_size=spatial_size, extra_dims=((4,),)):
         bboxes = bboxes.to(device)

-        output_bboxes = F.perspective_bounding_box(
+        output_bboxes = F.perspective_bounding_boxes(
             bboxes.as_subclass(torch.Tensor),
             format=bboxes.format,
             spatial_size=bboxes.spatial_size,

@@ -915,7 +915,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
         expected_bboxes = []
         for bbox in bboxes:
-            bbox = datapoints.BoundingBox(bbox, format=bboxes.format, spatial_size=bboxes.spatial_size)
+            bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, spatial_size=bboxes.spatial_size)
             expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
         if len(expected_bboxes) > 1:
             expected_bboxes = torch.stack(expected_bboxes)

@@ -929,12 +929,12 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
     "output_size",
     [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
 )
-def test_correctness_center_crop_bounding_box(device, output_size):
+def test_correctness_center_crop_bounding_boxes(device, output_size):
     def _compute_expected_bbox(bbox, output_size_):
         format_ = bbox.format
         spatial_size_ = bbox.spatial_size
         dtype = bbox.dtype
-        bbox = convert_format_bounding_box(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
+        bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)

         if len(output_size_) == 1:
             output_size_.append(output_size_[-1])

@@ -948,8 +948,8 @@ def test_correctness_center_crop_bounding_box(device, output_size):
             bbox[3].item(),
         ]
         out_bbox = torch.tensor(out_bbox)
-        out_bbox = convert_format_bounding_box(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
-        out_bbox = clamp_bounding_box(out_bbox, format=format_, spatial_size=output_size)
+        out_bbox = convert_format_bounding_boxes(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
+        out_bbox = clamp_bounding_boxes(out_bbox, format=format_, spatial_size=output_size)
         return out_bbox.to(dtype=dtype, device=bbox.device)

     for bboxes in make_bounding_boxes(extra_dims=((4,),)):

@@ -957,7 +957,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
         bboxes_format = bboxes.format
         bboxes_spatial_size = bboxes.spatial_size

-        output_boxes, output_spatial_size = F.center_crop_bounding_box(
+        output_boxes, output_spatial_size = F.center_crop_bounding_boxes(
             bboxes, bboxes_format, bboxes_spatial_size, output_size
         )

@@ -966,7 +966,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
         expected_bboxes = []
         for bbox in bboxes:
-            bbox = datapoints.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
+            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, spatial_size=bboxes_spatial_size)
             expected_bboxes.append(_compute_expected_bbox(bbox, output_size))

         if len(expected_bboxes) > 1:
test/test_transforms_v2_refactored.py

(diff collapsed; +113, -113, not shown)
test/test_transforms_v2_utils.py

@@ -20,20 +20,20 @@ MASK = make_detection_mask(size=IMAGE.spatial_size)
     ("sample", "types", "expected"),
     [
         ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
         ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
         ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True),
-        ((MASK,), (datapoints.Image, datapoints.BoundingBox), False),
+        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
+        ((MASK,), (datapoints.Image, datapoints.BoundingBoxes), False),
         ((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False),
-        ((IMAGE,), (datapoints.BoundingBox, datapoints.Mask), False),
+        ((IMAGE,), (datapoints.BoundingBoxes, datapoints.Mask), False),
         (
             (IMAGE, BOUNDING_BOX, MASK),
-            (datapoints.Image, datapoints.BoundingBox, datapoints.Mask),
+            (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
             True,
         ),
-        ((), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
+        ((), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),

@@ -58,30 +58,30 @@ def test_has_any(sample, types, expected):
     ("sample", "types", "expected"),
     [
         ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
         ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
-        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBox, datapoints.Mask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
         (
             (IMAGE, BOUNDING_BOX, MASK),
-            (datapoints.Image, datapoints.BoundingBox, datapoints.Mask),
+            (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
             True,
         ),
-        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox), False),
+        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), False),
         ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False),
-        ((IMAGE, MASK), (datapoints.BoundingBox, datapoints.Mask), False),
+        ((IMAGE, MASK), (datapoints.BoundingBoxes, datapoints.Mask), False),
         (
             (IMAGE, BOUNDING_BOX, MASK),
-            (datapoints.Image, datapoints.BoundingBox, datapoints.Mask),
+            (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
             True,
         ),
-        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
-        ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
-        ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBox, datapoints.Mask), False),
+        ((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
+        ((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
         (
             (IMAGE, BOUNDING_BOX, MASK),
-            (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBox, datapoints.Mask)),),
+            (lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)),),
             True,
         ),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
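For orientation, has_any/has_all check each element of a flat sample against the given types or predicate callables; a minimal usage sketch mirroring the rows above (it assumes the IMAGE, BOUNDING_BOX, and MASK fixtures defined at the top of this test file):

    from torchvision import datapoints
    from torchvision.transforms.v2.utils import has_all, has_any

    sample = (IMAGE, BOUNDING_BOX, MASK)
    assert has_any(sample, datapoints.BoundingBoxes)
    assert has_all(sample, datapoints.Image, datapoints.Mask)
    assert not has_any(sample, lambda obj: False)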
test/transforms_v2_dispatcher_infos.py

@@ -143,7 +143,7 @@ DISPATCHER_INFOS = [
         kernels={
             datapoints.Image: F.crop_image_tensor,
             datapoints.Video: F.crop_video,
-            datapoints.BoundingBox: F.crop_bounding_box,
+            datapoints.BoundingBoxes: F.crop_bounding_boxes,
             datapoints.Mask: F.crop_mask,
         },
         pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"),

@@ -153,7 +153,7 @@ DISPATCHER_INFOS = [
         kernels={
             datapoints.Image: F.resized_crop_image_tensor,
             datapoints.Video: F.resized_crop_video,
-            datapoints.BoundingBox: F.resized_crop_bounding_box,
+            datapoints.BoundingBoxes: F.resized_crop_bounding_boxes,
             datapoints.Mask: F.resized_crop_mask,
         },
         pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil),

@@ -163,7 +163,7 @@ DISPATCHER_INFOS = [
         kernels={
             datapoints.Image: F.pad_image_tensor,
             datapoints.Video: F.pad_video,
-            datapoints.BoundingBox: F.pad_bounding_box,
+            datapoints.BoundingBoxes: F.pad_bounding_boxes,
             datapoints.Mask: F.pad_mask,
         },
         pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"),

@@ -185,7 +185,7 @@ DISPATCHER_INFOS = [
         kernels={
             datapoints.Image: F.perspective_image_tensor,
             datapoints.Video: F.perspective_video,
-            datapoints.BoundingBox: F.perspective_bounding_box,
+            datapoints.BoundingBoxes: F.perspective_bounding_boxes,
             datapoints.Mask: F.perspective_mask,
         },
         pil_kernel_info=PILKernelInfo(F.perspective_image_pil),

@@ -199,7 +199,7 @@ DISPATCHER_INFOS = [
         kernels={
             datapoints.Image: F.elastic_image_tensor,
             datapoints.Video: F.elastic_video,
-            datapoints.BoundingBox: F.elastic_bounding_box,
+            datapoints.BoundingBoxes: F.elastic_bounding_boxes,
             datapoints.Mask: F.elastic_mask,
         },
         pil_kernel_info=PILKernelInfo(F.elastic_image_pil),

@@ -210,7 +210,7 @@ DISPATCHER_INFOS = [
         kernels={
             datapoints.Image: F.center_crop_image_tensor,
             datapoints.Video: F.center_crop_video,
-            datapoints.BoundingBox: F.center_crop_bounding_box,
+            datapoints.BoundingBoxes: F.center_crop_bounding_boxes,
             datapoints.Mask: F.center_crop_mask,
         },
         pil_kernel_info=PILKernelInfo(F.center_crop_image_pil),

@@ -374,15 +374,15 @@ DISPATCHER_INFOS = [
         ],
     ),
     DispatcherInfo(
-        F.clamp_bounding_box,
-        kernels={datapoints.BoundingBox: F.clamp_bounding_box},
+        F.clamp_bounding_boxes,
+        kernels={datapoints.BoundingBoxes: F.clamp_bounding_boxes},
         test_marks=[
             skip_dispatch_datapoint,
         ],
     ),
     DispatcherInfo(
-        F.convert_format_bounding_box,
-        kernels={datapoints.BoundingBox: F.convert_format_bounding_box},
+        F.convert_format_bounding_boxes,
+        kernels={datapoints.BoundingBoxes: F.convert_format_bounding_boxes},
         test_marks=[
             skip_dispatch_datapoint,
         ],
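For context, each DispatcherInfo above pairs a dispatcher with its per-datapoint-type kernels; the routing amounts to a type lookup, roughly like this simplified sketch (not the actual dispatch code in torchvision):

    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    kernels = {datapoints.BoundingBoxes: F.clamp_bounding_boxes}

    def dispatch(inpt, *args, **kwargs):
        # Pick the kernel registered for the input's datapoint type.
        return kernels[type(inpt)](inpt, *args, **kwargs)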
test/transforms_v2_kernel_infos.py

@@ -184,13 +184,13 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs):
     return other_args, dict(kwargs, fill=fill)


-def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
+def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, spatial_size, affine_matrix):
     def transform(bbox, affine_matrix_, format_, spatial_size_):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
         if not torch.is_floating_point(bbox):
             bbox = bbox.float()
-        bbox_xyxy = F.convert_format_bounding_box(
+        bbox_xyxy = F.convert_format_bounding_boxes(
             bbox.as_subclass(torch.Tensor),
             old_format=format_,
             new_format=datapoints.BoundingBoxFormat.XYXY,

@@ -214,18 +214,18 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size,
             ],
             dtype=bbox_xyxy.dtype,
         )
-        out_bbox = F.convert_format_bounding_box(
+        out_bbox = F.convert_format_bounding_boxes(
             out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
         )
         # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
-        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
+        out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, spatial_size=spatial_size_)
         out_bbox = out_bbox.to(dtype=in_dtype)
         return out_bbox

-    if bounding_box.ndim < 2:
-        bounding_box = [bounding_box]
+    if bounding_boxes.ndim < 2:
+        bounding_boxes = [bounding_boxes]

-    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
+    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_boxes]
     if len(expected_bboxes) > 1:
         expected_bboxes = torch.stack(expected_bboxes)
     else:
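The helper above applies a 2x3 affine matrix to each box corner in homogeneous coordinates; as a worked sketch, the crop matrix used later in this file ([[1, 0, -left], [0, 1, -top]]) is just a translation into crop coordinates (values assumed for illustration):

    import numpy as np

    top, left = 4, 3
    affine_matrix = np.array([[1, 0, -left], [0, 1, -top]], dtype="float32")
    corner = np.array([10.0, 12.0, 1.0])  # homogeneous (x, y, 1)
    print(affine_matrix @ corner)  # [7. 8.], i.e. (x - left, y - top)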
@@ -234,30 +234,30 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size,
return
expected_bboxes
def
sample_inputs_convert_format_bounding_box
():
def
sample_inputs_convert_format_bounding_box
es
():
formats
=
list
(
datapoints
.
BoundingBoxFormat
)
for
bounding_box_loader
,
new_format
in
itertools
.
product
(
make_bounding_box_loaders
(
formats
=
formats
),
formats
):
yield
ArgsKwargs
(
bounding_box_loader
,
old_format
=
bounding_box_loader
.
format
,
new_format
=
new_format
)
for
bounding_box
es
_loader
,
new_format
in
itertools
.
product
(
make_bounding_box_loaders
(
formats
=
formats
),
formats
):
yield
ArgsKwargs
(
bounding_box
es
_loader
,
old_format
=
bounding_box
es
_loader
.
format
,
new_format
=
new_format
)
def
reference_convert_format_bounding_box
(
bounding_box
,
old_format
,
new_format
):
def
reference_convert_format_bounding_box
es
(
bounding_box
es
,
old_format
,
new_format
):
return
torchvision
.
ops
.
box_convert
(
bounding_box
,
in_fmt
=
old_format
.
name
.
lower
(),
out_fmt
=
new_format
.
name
.
lower
()
).
to
(
bounding_box
.
dtype
)
bounding_box
es
,
in_fmt
=
old_format
.
name
.
lower
(),
out_fmt
=
new_format
.
name
.
lower
()
).
to
(
bounding_box
es
.
dtype
)
def
reference_inputs_convert_format_bounding_box
():
for
args_kwargs
in
sample_inputs_convert_format_bounding_box
():
def
reference_inputs_convert_format_bounding_box
es
():
for
args_kwargs
in
sample_inputs_convert_format_bounding_box
es
():
if
len
(
args_kwargs
.
args
[
0
].
shape
)
==
2
:
yield
args_kwargs
KERNEL_INFOS
.
append
(
KernelInfo
(
F
.
convert_format_bounding_box
,
sample_inputs_fn
=
sample_inputs_convert_format_bounding_box
,
reference_fn
=
reference_convert_format_bounding_box
,
reference_inputs_fn
=
reference_inputs_convert_format_bounding_box
,
F
.
convert_format_bounding_box
es
,
sample_inputs_fn
=
sample_inputs_convert_format_bounding_box
es
,
reference_fn
=
reference_convert_format_bounding_box
es
,
reference_inputs_fn
=
reference_inputs_convert_format_bounding_box
es
,
logs_usage
=
True
,
closeness_kwargs
=
{
((
"TestKernels"
,
"test_against_reference"
),
torch
.
int64
,
"cpu"
):
dict
(
atol
=
1
,
rtol
=
0
),
...
...
@@ -290,11 +290,11 @@ def reference_inputs_crop_image_tensor():
yield
ArgsKwargs
(
image_loader
,
**
params
)
def
sample_inputs_crop_bounding_box
():
for
bounding_box_loader
,
params
in
itertools
.
product
(
def
sample_inputs_crop_bounding_box
es
():
for
bounding_box
es
_loader
,
params
in
itertools
.
product
(
make_bounding_box_loaders
(),
[
_CROP_PARAMS
[
0
],
_CROP_PARAMS
[
-
1
]]
):
yield
ArgsKwargs
(
bounding_box_loader
,
format
=
bounding_box_loader
.
format
,
**
params
)
yield
ArgsKwargs
(
bounding_box
es
_loader
,
format
=
bounding_box
es
_loader
.
format
,
**
params
)
def
sample_inputs_crop_mask
():
...
...
@@ -312,27 +312,27 @@ def sample_inputs_crop_video():
        yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8)


-def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
+def reference_crop_bounding_boxes(bounding_boxes, *, format, top, left, height, width):
    affine_matrix = np.array(
        [
            [1, 0, -left],
            [0, 1, -top],
        ],
-        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
+        dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
    )

    spatial_size = (height, width)
-    expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    expected_bboxes = reference_affine_bounding_boxes_helper(
+        bounding_boxes, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
    )
    return expected_bboxes, spatial_size


-def reference_inputs_crop_bounding_box():
-    for bounding_box_loader, params in itertools.product(
+def reference_inputs_crop_bounding_boxes():
+    for bounding_boxes_loader, params in itertools.product(
        make_bounding_box_loaders(extra_dims=((), (4,))), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
    ):
-        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)
+        yield ArgsKwargs(bounding_boxes_loader, format=bounding_boxes_loader.format, **params)


KERNEL_INFOS.extend(
...
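`reference_crop_bounding_boxes` above models a crop as the pure translation `[[1, 0, -left], [0, 1, -top]]`. A small sketch of that idea in isolation (the helper call is replaced by a direct matrix product; values are made up):

```python
import numpy as np

top, left = 5, 3
affine_matrix = np.array([[1, 0, -left], [0, 1, -top]], dtype="float32")

# A box corner (x, y) in homogeneous coordinates (x, y, 1).
corner = np.array([10.0, 20.0, 1.0], dtype="float32")
# Cropping shifts the corner by (-left, -top): (10, 20) -> (7, 15).
assert np.allclose(affine_matrix @ corner, [7.0, 15.0])
```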
@@ -346,10 +346,10 @@ KERNEL_INFOS.extend(
            float32_vs_uint8=True,
        ),
        KernelInfo(
-            F.crop_bounding_box,
-            sample_inputs_fn=sample_inputs_crop_bounding_box,
-            reference_fn=reference_crop_bounding_box,
-            reference_inputs_fn=reference_inputs_crop_bounding_box,
+            F.crop_bounding_boxes,
+            sample_inputs_fn=sample_inputs_crop_bounding_boxes,
+            reference_fn=reference_crop_bounding_boxes,
+            reference_inputs_fn=reference_inputs_crop_bounding_boxes,
        ),
        KernelInfo(
            F.crop_mask,
...
@@ -406,9 +406,9 @@ def reference_inputs_resized_crop_image_tensor():
    )


-def sample_inputs_resized_crop_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders():
-        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **_RESIZED_CROP_PARAMS[0])
+def sample_inputs_resized_crop_bounding_boxes():
+    for bounding_boxes_loader in make_bounding_box_loaders():
+        yield ArgsKwargs(bounding_boxes_loader, format=bounding_boxes_loader.format, **_RESIZED_CROP_PARAMS[0])


def sample_inputs_resized_crop_mask():
...
@@ -436,8 +436,8 @@ KERNEL_INFOS.extend(
            },
        ),
        KernelInfo(
-            F.resized_crop_bounding_box,
-            sample_inputs_fn=sample_inputs_resized_crop_bounding_box,
+            F.resized_crop_bounding_boxes,
+            sample_inputs_fn=sample_inputs_resized_crop_bounding_boxes,
        ),
        KernelInfo(
            F.resized_crop_mask,
...
@@ -500,14 +500,14 @@ def reference_inputs_pad_image_tensor():
        yield ArgsKwargs(image_loader, fill=fill, **params)


-def sample_inputs_pad_bounding_box():
-    for bounding_box_loader, padding in itertools.product(
+def sample_inputs_pad_bounding_boxes():
+    for bounding_boxes_loader, padding in itertools.product(
        make_bounding_box_loaders(), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    ):
        yield ArgsKwargs(
-            bounding_box_loader,
-            format=bounding_box_loader.format,
-            spatial_size=bounding_box_loader.spatial_size,
+            bounding_boxes_loader,
+            format=bounding_boxes_loader.format,
+            spatial_size=bounding_boxes_loader.spatial_size,
            padding=padding,
            padding_mode="constant",
        )
...
@@ -530,7 +530,7 @@ def sample_inputs_pad_video():
        yield ArgsKwargs(video_loader, padding=[1])


-def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, padding_mode):
+def reference_pad_bounding_boxes(bounding_boxes, *, format, spatial_size, padding, padding_mode):
    left, right, top, bottom = _parse_pad_padding(padding)
...
@@ -539,26 +539,26 @@ def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, padding_mode):
            [1, 0, left],
            [0, 1, top],
        ],
-        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
+        dtype="float64" if bounding_boxes.dtype == torch.float64 else "float32",
    )

    height = spatial_size[0] + top + bottom
    width = spatial_size[1] + left + right

-    expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=format, spatial_size=(height, width), affine_matrix=affine_matrix
+    expected_bboxes = reference_affine_bounding_boxes_helper(
+        bounding_boxes, format=format, spatial_size=(height, width), affine_matrix=affine_matrix
    )
    return expected_bboxes, (height, width)


-def reference_inputs_pad_bounding_box():
-    for bounding_box_loader, padding in itertools.product(
+def reference_inputs_pad_bounding_boxes():
+    for bounding_boxes_loader, padding in itertools.product(
        make_bounding_box_loaders(extra_dims=((), (4,))), [1, (1,), (1, 2), (1, 2, 3, 4), [1], [1, 2], [1, 2, 3, 4]]
    ):
        yield ArgsKwargs(
-            bounding_box_loader,
-            format=bounding_box_loader.format,
-            spatial_size=bounding_box_loader.spatial_size,
+            bounding_boxes_loader,
+            format=bounding_boxes_loader.format,
+            spatial_size=bounding_boxes_loader.spatial_size,
            padding=padding,
            padding_mode="constant",
        )
...
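`reference_pad_bounding_boxes` above relies on torchvision's private `_parse_pad_padding` to normalize the several accepted shapes of `padding`. A hypothetical re-implementation of that normalization, following torchvision's documented convention (one value for all sides, two values for left/right and top/bottom, four values for left, top, right, bottom):

```python
def parse_padding(padding):
    # Hypothetical stand-in for torchvision's private _parse_pad_padding.
    if isinstance(padding, int):
        padding = [padding]
    if len(padding) == 1:
        left = right = top = bottom = padding[0]
    elif len(padding) == 2:
        left = right = padding[0]
        top = bottom = padding[1]
    else:  # four values: left, top, right, bottom
        left, top, right, bottom = padding
    return left, right, top, bottom

# Padding grows the canvas, exactly as computed in the reference above:
left, right, top, bottom = parse_padding((1, 2, 3, 4))
height, width = 10 + top + bottom, 20 + left + right  # -> (16, 24)
```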
@@ -591,10 +591,10 @@ KERNEL_INFOS.extend(
            ],
        ),
        KernelInfo(
-            F.pad_bounding_box,
-            sample_inputs_fn=sample_inputs_pad_bounding_box,
-            reference_fn=reference_pad_bounding_box,
-            reference_inputs_fn=reference_inputs_pad_bounding_box,
+            F.pad_bounding_boxes,
+            sample_inputs_fn=sample_inputs_pad_bounding_boxes,
+            reference_fn=reference_pad_bounding_boxes,
+            reference_inputs_fn=reference_inputs_pad_bounding_boxes,
            test_marks=[
                xfail_jit_python_scalar_arg("padding"),
            ],
...
@@ -655,12 +655,12 @@ def reference_inputs_perspective_image_tensor():
    )


-def sample_inputs_perspective_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders():
+def sample_inputs_perspective_bounding_boxes():
+    for bounding_boxes_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
-            bounding_box_loader,
-            format=bounding_box_loader.format,
-            spatial_size=bounding_box_loader.spatial_size,
+            bounding_boxes_loader,
+            format=bounding_boxes_loader.format,
+            spatial_size=bounding_boxes_loader.spatial_size,
            startpoints=None,
            endpoints=None,
            coefficients=_PERSPECTIVE_COEFFS[0],
...
@@ -712,8 +712,8 @@ KERNEL_INFOS.extend(
            test_marks=[xfail_jit_python_scalar_arg("fill")],
        ),
        KernelInfo(
-            F.perspective_bounding_box,
-            sample_inputs_fn=sample_inputs_perspective_bounding_box,
+            F.perspective_bounding_boxes,
+            sample_inputs_fn=sample_inputs_perspective_bounding_boxes,
            closeness_kwargs={
                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6),
...
@@ -767,13 +767,13 @@ def reference_inputs_elastic_image_tensor():
            yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill)


-def sample_inputs_elastic_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders():
-        displacement = _get_elastic_displacement(bounding_box_loader.spatial_size)
+def sample_inputs_elastic_bounding_boxes():
+    for bounding_boxes_loader in make_bounding_box_loaders():
+        displacement = _get_elastic_displacement(bounding_boxes_loader.spatial_size)
        yield ArgsKwargs(
-            bounding_box_loader,
-            format=bounding_box_loader.format,
-            spatial_size=bounding_box_loader.spatial_size,
+            bounding_boxes_loader,
+            format=bounding_boxes_loader.format,
+            spatial_size=bounding_boxes_loader.spatial_size,
            displacement=displacement,
        )
...
@@ -804,8 +804,8 @@ KERNEL_INFOS.extend(
            test_marks=[xfail_jit_python_scalar_arg("fill")],
        ),
        KernelInfo(
-            F.elastic_bounding_box,
-            sample_inputs_fn=sample_inputs_elastic_bounding_box,
+            F.elastic_bounding_boxes,
+            sample_inputs_fn=sample_inputs_elastic_bounding_boxes,
        ),
        KernelInfo(
            F.elastic_mask,
...
@@ -845,12 +845,12 @@ def reference_inputs_center_crop_image_tensor():
        yield ArgsKwargs(image_loader, output_size=output_size)


-def sample_inputs_center_crop_bounding_box():
-    for bounding_box_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES):
+def sample_inputs_center_crop_bounding_boxes():
+    for bounding_boxes_loader, output_size in itertools.product(make_bounding_box_loaders(), _CENTER_CROP_OUTPUT_SIZES):
        yield ArgsKwargs(
-            bounding_box_loader,
-            format=bounding_box_loader.format,
-            spatial_size=bounding_box_loader.spatial_size,
+            bounding_boxes_loader,
+            format=bounding_boxes_loader.format,
+            spatial_size=bounding_boxes_loader.spatial_size,
            output_size=output_size,
        )
...
@@ -887,8 +887,8 @@ KERNEL_INFOS.extend(
            ],
        ),
        KernelInfo(
-            F.center_crop_bounding_box,
-            sample_inputs_fn=sample_inputs_center_crop_bounding_box,
+            F.center_crop_bounding_boxes,
+            sample_inputs_fn=sample_inputs_center_crop_bounding_boxes,
            test_marks=[
                xfail_jit_python_scalar_arg("output_size"),
            ],
...
@@ -1482,19 +1482,19 @@ KERNEL_INFOS.extend(
)


-def sample_inputs_clamp_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders():
+def sample_inputs_clamp_bounding_boxes():
+    for bounding_boxes_loader in make_bounding_box_loaders():
        yield ArgsKwargs(
-            bounding_box_loader,
-            format=bounding_box_loader.format,
-            spatial_size=bounding_box_loader.spatial_size,
+            bounding_boxes_loader,
+            format=bounding_boxes_loader.format,
+            spatial_size=bounding_boxes_loader.spatial_size,
        )


KERNEL_INFOS.append(
    KernelInfo(
-        F.clamp_bounding_box,
-        sample_inputs_fn=sample_inputs_clamp_bounding_box,
+        F.clamp_bounding_boxes,
+        sample_inputs_fn=sample_inputs_clamp_bounding_boxes,
        logs_usage=True,
    )
)
...
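For context, clamping restricts box coordinates to the canvas given by `spatial_size`. A hedged sketch of the effect for XYXY boxes, written with plain tensor ops rather than the renamed kernel:

```python
import torch

height, width = 480, 640  # spatial_size = (height, width)
boxes = torch.tensor([[-10.0, 5.0, 700.0, 100.0]])  # x1, y1, x2, y2

clamped = boxes.clone()
clamped[:, 0::2] = clamped[:, 0::2].clamp(min=0, max=width)   # x1, x2
clamped[:, 1::2] = clamped[:, 1::2].clamp(min=0, max=height)  # y1, y2
# -> tensor([[  0.,   5., 640., 100.]])
```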
torchvision/datapoints/__init__.py
View file @ 332bff93

from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

-from ._bounding_box import BoundingBox, BoundingBoxFormat
+from ._bounding_box import BoundingBoxes, BoundingBoxFormat
from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
from ._mask import Mask
...
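The practical upshot of this hunk: downstream code has to switch to the plural import. For example:

```python
from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat  # new name
# from torchvision.datapoints import BoundingBox  # old name, no longer resolves after this commit
```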
torchvision/datapoints/_bounding_box.py
View file @ 332bff93

...
@@ -24,7 +24,7 @@ class BoundingBoxFormat(Enum):
    CXCYWH = "CXCYWH"


-class BoundingBox(Datapoint):
+class BoundingBoxes(Datapoint):
    """[BETA] :class:`torch.Tensor` subclass for bounding boxes.

    Args:
...
@@ -43,11 +43,11 @@ class BoundingBox(Datapoint):
    spatial_size: Tuple[int, int]

    @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox:
-        bounding_box = tensor.as_subclass(cls)
-        bounding_box.format = format
-        bounding_box.spatial_size = spatial_size
-        return bounding_box
+    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBoxes:
+        bounding_boxes = tensor.as_subclass(cls)
+        bounding_boxes.format = format
+        bounding_boxes.spatial_size = spatial_size
+        return bounding_boxes

    def __new__(
        cls,
...
@@ -58,7 +58,7 @@ class BoundingBox(Datapoint):
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        requires_grad: Optional[bool] = None,
-    ) -> BoundingBox:
+    ) -> BoundingBoxes:
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)

        if isinstance(format, str):
...
@@ -69,17 +69,17 @@ class BoundingBox(Datapoint):
    @classmethod
    def wrap_like(
        cls,
-        other: BoundingBox,
+        other: BoundingBoxes,
        tensor: torch.Tensor,
        *,
        format: Optional[BoundingBoxFormat] = None,
        spatial_size: Optional[Tuple[int, int]] = None,
-    ) -> BoundingBox:
-        """Wrap a :class:`torch.Tensor` as :class:`BoundingBox` from a reference.
+    ) -> BoundingBoxes:
+        """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.

        Args:
-            other (BoundingBox): Reference bounding box.
-            tensor (Tensor): Tensor to be wrapped as :class:`BoundingBox`
+            other (BoundingBoxes): Reference bounding box.
+            tensor (Tensor): Tensor to be wrapped as :class:`BoundingBoxes`
            format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
                reference.
            spatial_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
...
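A minimal usage sketch of the `wrap_like` pattern documented above: run plain tensor math, then re-attach the metadata from a reference instance (values are made up; the datapoints API was still in beta at this commit):

```python
import torch
from torchvision import datapoints

ref = datapoints.BoundingBoxes(
    torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(32, 32),
)
# Plain tensor ops drop the subclass and its metadata ...
plain = ref.as_subclass(torch.Tensor) + 2.0
# ... wrap_like restores both from the reference.
boxes = datapoints.BoundingBoxes.wrap_like(ref, plain)
assert boxes.format is ref.format and boxes.spatial_size == ref.spatial_size
```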
@@ -98,17 +98,17 @@ class BoundingBox(Datapoint):
    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
        return self._make_repr(format=self.format, spatial_size=self.spatial_size)

-    def horizontal_flip(self) -> BoundingBox:
-        output = self._F.horizontal_flip_bounding_box(
+    def horizontal_flip(self) -> BoundingBoxes:
+        output = self._F.horizontal_flip_bounding_boxes(
            self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
        )
-        return BoundingBox.wrap_like(self, output)
+        return BoundingBoxes.wrap_like(self, output)

-    def vertical_flip(self) -> BoundingBox:
-        output = self._F.vertical_flip_bounding_box(
+    def vertical_flip(self) -> BoundingBoxes:
+        output = self._F.vertical_flip_bounding_boxes(
            self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
        )
-        return BoundingBox.wrap_like(self, output)
+        return BoundingBoxes.wrap_like(self, output)

    def resize(  # type: ignore[override]
        self,
...
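Each method above follows the same shape: dispatch to the renamed plural kernel on the unwrapped tensor, then rewrap. For intuition, a hedged sketch of what `horizontal_flip` does to XYXY coordinates:

```python
import torch
from torchvision import datapoints

boxes = datapoints.BoundingBoxes(
    torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(32, 32),  # (height, width)
)
flipped = boxes.horizontal_flip()
# x-coordinates are reflected across the width: x' = width - x,
# so (0, 0, 10, 10) on a 32-wide canvas becomes (22, 0, 32, 10).
```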
@@ -116,26 +116,26 @@ class BoundingBox(Datapoint):
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        max_size: Optional[int] = None,
        antialias: Optional[Union[str, bool]] = "warn",
-    ) -> BoundingBox:
-        output, spatial_size = self._F.resize_bounding_box(
+    ) -> BoundingBoxes:
+        output, spatial_size = self._F.resize_bounding_boxes(
            self.as_subclass(torch.Tensor),
            spatial_size=self.spatial_size,
            size=size,
            max_size=max_size,
        )
-        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
+        return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)

-    def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
-        output, spatial_size = self._F.crop_bounding_box(
+    def crop(self, top: int, left: int, height: int, width: int) -> BoundingBoxes:
+        output, spatial_size = self._F.crop_bounding_boxes(
            self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
        )
-        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
+        return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)

-    def center_crop(self, output_size: List[int]) -> BoundingBox:
-        output, spatial_size = self._F.center_crop_bounding_box(
+    def center_crop(self, output_size: List[int]) -> BoundingBoxes:
+        output, spatial_size = self._F.center_crop_bounding_boxes(
            self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size
        )
-        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
+        return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)

    def resized_crop(
        self,
...
@@ -146,26 +146,26 @@ class BoundingBox(Datapoint):
        size: List[int],
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[Union[str, bool]] = "warn",
-    ) -> BoundingBox:
-        output, spatial_size = self._F.resized_crop_bounding_box(
+    ) -> BoundingBoxes:
+        output, spatial_size = self._F.resized_crop_bounding_boxes(
            self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
        )
-        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
+        return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)

    def pad(
        self,
        padding: Union[int, Sequence[int]],
        fill: Optional[Union[int, float, List[float]]] = None,
        padding_mode: str = "constant",
-    ) -> BoundingBox:
-        output, spatial_size = self._F.pad_bounding_box(
+    ) -> BoundingBoxes:
+        output, spatial_size = self._F.pad_bounding_boxes(
            self.as_subclass(torch.Tensor),
            format=self.format,
            spatial_size=self.spatial_size,
            padding=padding,
            padding_mode=padding_mode,
        )
-        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
+        return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)

    def rotate(
        self,
...
@@ -174,8 +174,8 @@ class BoundingBox(Datapoint):
        expand: bool = False,
        center: Optional[List[float]] = None,
        fill: _FillTypeJIT = None,
-    ) -> BoundingBox:
-        output, spatial_size = self._F.rotate_bounding_box(
+    ) -> BoundingBoxes:
+        output, spatial_size = self._F.rotate_bounding_boxes(
            self.as_subclass(torch.Tensor),
            format=self.format,
            spatial_size=self.spatial_size,
...
@@ -183,7 +183,7 @@ class BoundingBox(Datapoint):
            expand=expand,
            center=center,
        )
-        return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
+        return BoundingBoxes.wrap_like(self, output, spatial_size=spatial_size)

    def affine(
        self,
...
@@ -194,8 +194,8 @@ class BoundingBox(Datapoint):
        interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
        fill: _FillTypeJIT = None,
        center: Optional[List[float]] = None,
-    ) -> BoundingBox:
-        output = self._F.affine_bounding_box(
+    ) -> BoundingBoxes:
+        output = self._F.affine_bounding_boxes(
            self.as_subclass(torch.Tensor),
            self.format,
            self.spatial_size,
...
@@ -205,7 +205,7 @@ class BoundingBox(Datapoint):
            shear=shear,
            center=center,
        )
-        return BoundingBox.wrap_like(self, output)
+        return BoundingBoxes.wrap_like(self, output)

    def perspective(
        self,
...
@@ -214,8 +214,8 @@ class BoundingBox(Datapoint):
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
        coefficients: Optional[List[float]] = None,
-    ) -> BoundingBox:
-        output = self._F.perspective_bounding_box(
+    ) -> BoundingBoxes:
+        output = self._F.perspective_bounding_boxes(
            self.as_subclass(torch.Tensor),
            format=self.format,
            spatial_size=self.spatial_size,
...
@@ -223,15 +223,15 @@ class BoundingBox(Datapoint):
            endpoints=endpoints,
            coefficients=coefficients,
        )
-        return BoundingBox.wrap_like(self, output)
+        return BoundingBoxes.wrap_like(self, output)

    def elastic(
        self,
        displacement: torch.Tensor,
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        fill: _FillTypeJIT = None,
-    ) -> BoundingBox:
-        output = self._F.elastic_bounding_box(
+    ) -> BoundingBoxes:
+        output = self._F.elastic_bounding_boxes(
            self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement
        )
-        return BoundingBox.wrap_like(self, output)
+        return BoundingBoxes.wrap_like(self, output)
torchvision/datapoints/_datapoint.py
View file @ 332bff93

...
@@ -138,8 +138,8 @@ class Datapoint(torch.Tensor):
        # *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad`
        # attribute is cleared, so we need to refill it before we return.
        # Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is
-        # `BoundingBox.format` and `BoundingBox.spatial_size`, which are immutable and thus implicitly deep-copied by
-        # `BoundingBox.clone()`.
+        # `BoundingBoxes.format` and `BoundingBoxes.spatial_size`, which are immutable and thus implicitly deep-copied by
+        # `BoundingBoxes.clone()`.
        return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]

    def horizontal_flip(self) -> Datapoint:
...
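A sketch of the behaviour that comment describes (assuming the beta constructor accepts a tensor that already requires grad):

```python
import copy

import torch
from torchvision import datapoints

boxes = datapoints.BoundingBoxes(
    torch.tensor([[1.0, 2.0, 3.0, 4.0]], requires_grad=True),
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(10, 10),
)
duplicate = copy.deepcopy(boxes)
# detach().clone() cleared requires_grad; __deepcopy__ restored it,
# and the immutable format/spatial_size metadata came along for free.
assert duplicate.requires_grad == boxes.requires_grad
assert duplicate.format is boxes.format
```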
torchvision/datapoints/_dataset_wrapper.py
View file @ 332bff93

...
@@ -44,7 +44,7 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
          the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
          preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
        * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
-          coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
+          coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint.
        * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as a list of dicts, the wrapper returns a
          dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
          in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is
...
@@ -56,7 +56,7 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
          a dictionary with the key-value-pairs ``"masks"`` (as :class:`~torchvision.datapoints.Mask` datapoint) and
          ``"labels"``.
        * :class:`~torchvision.datasets.WIDERFace`: The value for key ``"bbox"`` in the target is converted to ``XYXY``
-          coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
+          coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBoxes` datapoint.

    Image classification datasets
...
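A hedged usage sketch of the wrapper whose docstring is updated above (the dataset paths are placeholders, and the export location of `wrap_dataset_for_transforms_v2` under `torchvision.datapoints` is assumed from this beta API):

```python
from torchvision import datapoints, datasets

dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
dataset = datapoints.wrap_dataset_for_transforms_v2(dataset)

img, target = dataset[0]
# After this commit, target["boxes"] is a datapoints.BoundingBoxes in XYXY format.
```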
@@ -360,8 +360,8 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
        target["image_id"] = image_id

        if "boxes" in target_keys:
-            target["boxes"] = F.convert_format_bounding_box(
-                datapoints.BoundingBox(
+            target["boxes"] = F.convert_format_bounding_boxes(
+                datapoints.BoundingBoxes(
                    batched_target["bbox"],
                    format=datapoints.BoundingBoxFormat.XYWH,
                    spatial_size=spatial_size,
...
@@ -442,7 +442,7 @@ def voc_detection_wrapper_factory(dataset, target_keys):
        target = {}

        if "boxes" in target_keys:
-            target["boxes"] = datapoints.BoundingBox(
+            target["boxes"] = datapoints.BoundingBoxes(
                [
                    [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
                    for bndbox in batched_instances["bndbox"]
...
@@ -481,8 +481,8 @@ def celeba_wrapper_factory(dataset, target_keys):
            target,
            target_types=dataset.target_type,
            type_wrappers={
-                "bbox": lambda item: F.convert_format_bounding_box(
-                    datapoints.BoundingBox(
+                "bbox": lambda item: F.convert_format_bounding_boxes(
+                    datapoints.BoundingBoxes(
                        item,
                        format=datapoints.BoundingBoxFormat.XYWH,
                        spatial_size=(image.height, image.width),
...
@@ -532,7 +532,7 @@ def kitti_wrapper_factory(dataset, target_keys):
        target = {}

        if "boxes" in target_keys:
-            target["boxes"] = datapoints.BoundingBox(
+            target["boxes"] = datapoints.BoundingBoxes(
                batched_target["bbox"],
                format=datapoints.BoundingBoxFormat.XYXY,
                spatial_size=(image.height, image.width),
...
@@ -628,8 +628,8 @@ def widerface_wrapper(dataset, target_keys):
        target = {key: target[key] for key in target_keys}

        if "bbox" in target_keys:
-            target["bbox"] = F.convert_format_bounding_box(
-                datapoints.BoundingBox(
+            target["bbox"] = F.convert_format_bounding_boxes(
+                datapoints.BoundingBoxes(
                    target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
                ),
                new_format=datapoints.BoundingBoxFormat.XYXY,
...
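The wrapper factories above all share one chained pattern: wrap the raw box data as a `BoundingBoxes` datapoint, then convert its format in a single call. A standalone sketch with made-up values:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

raw = torch.tensor([[10, 20, 30, 40]])  # XYWH, as stored by e.g. WIDERFace/CelebA
boxes = F.convert_format_bounding_boxes(
    datapoints.BoundingBoxes(
        raw,
        format=datapoints.BoundingBoxFormat.XYWH,
        spatial_size=(480, 640),
    ),
    new_format=datapoints.BoundingBoxFormat.XYXY,
)
# -> BoundingBoxes([[10, 20, 40, 60]], format=XYXY, spatial_size=(480, 640))
```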