OpenDAS / vision · commit 8faa1b14 (unverified)

Simplify query_bounding_boxes logic (#7786)

Authored Aug 08, 2023 by Nicolas Hug, committed via GitHub on Aug 08, 2023
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Parent: 9b82df43

Showing 11 changed files with 67 additions and 97 deletions (+67 -97):
  test/common_utils.py                             +3   -3
  test/test_datapoints.py                          +7   -1
  test/test_prototype_transforms.py                +4   -2
  test/test_transforms_v2.py                       +0  -12
  test/test_transforms_v2_functional.py           +28  -50
  test/transforms_v2_kernel_infos.py               +3  -10
  torchvision/datapoints/_bounding_box.py         +10   -0
  torchvision/prototype/transforms/_geometry.py    +2   -2
  torchvision/transforms/v2/_geometry.py           +2   -2
  torchvision/transforms/v2/_misc.py               +2   -8
  torchvision/transforms/v2/utils.py               +6   -7
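In short, this commit renames the internal helper query_bounding_boxes to get_bounding_boxes and leans on the convention that a sample contains exactly one BoundingBoxes entry, itself always a 2D (num_boxes, 4) tensor. A minimal before/after sketch of the call sites touched below, assuming flat_inputs is the flattened sample list used by the v2 transforms:

    # before (removed in this commit)
    from torchvision.transforms.v2.utils import query_bounding_boxes
    bboxes = query_bounding_boxes(flat_inputs)

    # after
    from torchvision.transforms.v2.utils import get_bounding_boxes
    bboxes = get_bounding_boxes(flat_inputs)  # returns the single BoundingBoxes entry of the sample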
test/common_utils.py

@@ -691,7 +691,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORT
     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]

-    spatial_size = _parse_size(spatial_size, name="canvas_size")
+    spatial_size = _parse_size(spatial_size, name="spatial_size")

     def fn(shape, dtype, device):
         *batch_dims, num_coordinates = shape
@@ -702,12 +702,12 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORT
             format=format, canvas_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
         )

-    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
+    return BoundingBoxesLoader(fn, shape=(*extra_dims[-1:], 4), dtype=dtype, format=format, spatial_size=spatial_size)


 def make_bounding_box_loaders(
     *,
-    extra_dims=DEFAULT_EXTRA_DIMS,
+    extra_dims=tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2),
     formats=tuple(datapoints.BoundingBoxFormat),
     spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
     dtypes=(torch.float32, torch.float64, torch.int64),
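The loader changes above follow from bounding boxes now always being 2D: a per-sample loader keeps at most one leading batch dimension, so the default extra_dims for make_bounding_box_loaders is filtered down to dimension tuples shorter than two. A small illustration of that filter, using a hypothetical DEFAULT_EXTRA_DIMS value (the real constant lives in test/common_utils.py):

    DEFAULT_EXTRA_DIMS = ((), (4,), (2, 3), (2, 3, 4))  # hypothetical value, for illustration only
    filtered = tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2)
    print(filtered)  # ((), (4,)) -- only "no batch dim" and "one batch dim" remain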
test/test_datapoints.py

@@ -22,7 +22,7 @@ def test_mask_instance(data):
     assert mask.ndim == 3 and mask.shape[0] == 1

-@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]])
+@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
 @pytest.mark.parametrize(
     "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
 )
@@ -35,6 +35,12 @@ def test_bbox_instance(data, format):
     assert bboxes.format == format

+
+def test_bbox_dim_error():
+    data_3d = [[[1, 2, 3, 4]]]
+    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
+        datapoints.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))
+
 @pytest.mark.parametrize(
     ("data", "input_requires_grad", "expected_requires_grad"),
     [
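One practical consequence covered by the new [1, 2, 3, 4] parametrization: a single box passed as a flat 1D sequence is now accepted and promoted to a 2D (1, 4) tensor by BoundingBoxes._wrap (see the torchvision/datapoints/_bounding_box.py change further down). A minimal sketch:

    from torchvision import datapoints

    box = datapoints.BoundingBoxes([1, 2, 3, 4], format="XYXY", canvas_size=(32, 32))
    print(box.shape)  # torch.Size([1, 4]) -- the 1D input was unsqueezed into a single-row 2D tensor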
test/test_prototype_transforms.py

@@ -20,7 +20,7 @@ from prototype_common_utils import make_label
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
 from torchvision.transforms.v2._utils import _convert_fill_arg
-from torchvision.transforms.v2.functional import InterpolationMode, pil_to_tensor, to_image_pil
+from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -306,7 +306,9 @@ class TestFixedSizeCrop:
         bounding_boxes = make_bounding_box(
             format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
         )
-        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes")
+        mock = mocker.patch(
+            "torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes", wraps=clamp_bounding_boxes
+        )

         transform = transforms.FixedSizeCrop((-1, -1))
         mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
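The wraps= argument turns the patch into a spy: the real clamp_bounding_boxes kernel still runs, while the mock records the call for the assertions. This is standard unittest.mock behavior, which pytest-mock's mocker.patch forwards to; a small self-contained sketch, not taken from the repository:

    from unittest import mock

    def double(x):
        return 2 * x

    # Patching with wraps= produces a spy: the wrapped function still executes,
    # but the call and its arguments are recorded on the mock.
    with mock.patch(f"{__name__}.double", wraps=double) as spy:
        assert double(3) == 6          # real implementation still runs
        spy.assert_called_once_with(3)  # and the call was recorded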
test/test_transforms_v2.py

@@ -1654,18 +1654,6 @@ def test_sanitize_bounding_boxes_errors():
         different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
         transforms.SanitizeBoundingBoxes()(different_sizes)

-    with pytest.raises(ValueError, match="boxes must be of shape"):
-        bad_bbox = datapoints.BoundingBoxes(
-            # batch with 2 elements
-            [
-                [[0, 0, 10, 10]],
-                [[0, 0, 10, 10]],
-            ],
-            format=datapoints.BoundingBoxFormat.XYXY,
-            canvas_size=(20, 20),
-        )
-        different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
-        transforms.SanitizeBoundingBoxes()(different_sizes)

 @pytest.mark.parametrize(
     "import_statement",
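This error test is dropped because the failure now happens one step earlier: BoundingBoxes itself rejects 3D data at construction time (see the _wrap change below), so SanitizeBoundingBoxes can no longer receive a mis-shaped boxes tensor. A sketch of the equivalent check after this commit, assuming the pytest and datapoints names already imported in this test file:

    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
        datapoints.BoundingBoxes(
            [[[0, 0, 10, 10]], [[0, 0, 10, 10]]],  # the same "batch with 2 elements" data
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=(20, 20),
        )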
test/test_transforms_v2_functional.py

@@ -711,21 +711,20 @@ def _parse_padding(padding):
 @pytest.mark.parametrize("device", cpu_and_cuda())
 @pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
 def test_correctness_pad_bounding_boxes(device, padding):
-    def _compute_expected_bbox(bbox, padding_):
+    def _compute_expected_bbox(bbox, format, padding_):
         pad_left, pad_up, _, _ = _parse_padding(padding_)

         dtype = bbox.dtype
-        format = bbox.format
         bbox = (
             bbox.clone()
             if format == datapoints.BoundingBoxFormat.XYXY
-            else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+            else convert_format_bounding_boxes(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
         )

         bbox[0::2] += pad_left
         bbox[1::2] += pad_up

-        bbox = convert_format_bounding_boxes(bbox, new_format=format)
+        bbox = convert_format_bounding_boxes(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
         if bbox.dtype != dtype:
             # Temporary cast to original dtype
             # e.g. float32 -> int
@@ -737,7 +736,7 @@ def test_correctness_pad_bounding_boxes(device, padding):
         height, width = bbox.canvas_size
         return height + pad_up + pad_down, width + pad_left + pad_right

-    for bboxes in make_bounding_boxes():
+    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
         bboxes = bboxes.to(device)
         bboxes_format = bboxes.format
         bboxes_canvas_size = bboxes.canvas_size
@@ -748,18 +747,10 @@ def test_correctness_pad_bounding_boxes(device, padding):
         torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

-        if bboxes.ndim < 2 or bboxes.shape[0] == 0:
-            bboxes = [bboxes]
-
-        expected_bboxes = []
-        for bbox in bboxes:
-            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
-            expected_bboxes.append(_compute_expected_bbox(bbox, padding))
-
-        if len(expected_bboxes) > 1:
-            expected_bboxes = torch.stack(expected_bboxes)
-        else:
-            expected_bboxes = expected_bboxes[0]
+        expected_bboxes = torch.stack(
+            [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()]
+        ).reshape(bboxes.shape)

         torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
@@ -784,7 +775,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
     ],
 )
 def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
-    def _compute_expected_bbox(bbox, pcoeffs_):
+    def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_):
         m1 = np.array(
             [
                 [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
@@ -798,7 +789,9 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
             ]
         )

-        bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+        bbox_xyxy = convert_format_bounding_boxes(
+            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY
+        )
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -818,14 +811,11 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
                 np.max(transformed_points[:, 1]),
             ]
         )
-        out_bbox = datapoints.BoundingBoxes(
-            out_bbox,
-            format=datapoints.BoundingBoxFormat.XYXY,
-            canvas_size=bbox.canvas_size,
-            dtype=bbox.dtype,
-            device=bbox.device,
-        )
-        return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format))
+        out_bbox = torch.from_numpy(out_bbox)
+        out_bbox = convert_format_bounding_boxes(
+            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
+        )
+        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)

     canvas_size = (32, 38)
@@ -844,17 +834,13 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
             coefficients=pcoeffs,
         )

-        if bboxes.ndim < 2:
-            bboxes = [bboxes]
-
-        expected_bboxes = []
-        for bbox in bboxes:
-            bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, canvas_size=bboxes.canvas_size)
-            expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
-
-        if len(expected_bboxes) > 1:
-            expected_bboxes = torch.stack(expected_bboxes)
-        else:
-            expected_bboxes = expected_bboxes[0]
+        expected_bboxes = torch.stack(
+            [
+                _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs)
+                for b in bboxes.reshape(-1, 4).unbind()
+            ]
+        ).reshape(bboxes.shape)

         torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)
@@ -864,9 +850,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
     [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
 )
 def test_correctness_center_crop_bounding_boxes(device, output_size):
-    def _compute_expected_bbox(bbox, output_size_):
-        format_ = bbox.format
-        canvas_size_ = bbox.canvas_size
+    def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
         dtype = bbox.dtype
         bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
@@ -895,18 +879,12 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
             bboxes, bboxes_format, bboxes_canvas_size, output_size
         )

-        if bboxes.ndim < 2:
-            bboxes = [bboxes]
-
-        expected_bboxes = []
-        for bbox in bboxes:
-            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
-            expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
-
-        if len(expected_bboxes) > 1:
-            expected_bboxes = torch.stack(expected_bboxes)
-        else:
-            expected_bboxes = expected_bboxes[0]
+        expected_bboxes = torch.stack(
+            [
+                _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size)
+                for b in bboxes.reshape(-1, 4).unbind()
+            ]
+        ).reshape(bboxes.shape)

         torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
         torch.testing.assert_close(output_canvas_size, output_size)
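All three reference helpers above now share the same shape-agnostic pattern: flatten whatever leading dimensions the boxes have into (N, 4) with reshape(-1, 4), unbind into individual boxes, apply the per-box reference, then stack and reshape back. A small self-contained sketch of that pattern (apply_per_box and the +1 "reference" are made up for illustration):

    import torch

    def apply_per_box(reference_fn, boxes):
        # Flatten to (N, 4), run the reference on each box, then restore the original shape.
        return torch.stack([reference_fn(b) for b in boxes.reshape(-1, 4).unbind()]).reshape(boxes.shape)

    boxes = torch.arange(4 * 4, dtype=torch.float32).reshape(4, 4)  # a batch of 4 boxes
    shifted = apply_per_box(lambda b: b + 1, boxes)                 # trivial per-box reference
    assert shifted.shape == boxes.shape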
test/transforms_v2_kernel_infos.py

@@ -222,16 +222,9 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
         out_bbox = out_bbox.to(dtype=in_dtype)
         return out_bbox

-    if bounding_boxes.ndim < 2:
-        bounding_boxes = [bounding_boxes]
-
-    expected_bboxes = [transform(bbox, affine_matrix, format, canvas_size) for bbox in bounding_boxes]
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
-
-    return expected_bboxes
+    return torch.stack(
+        [transform(b, affine_matrix, format, canvas_size) for b in bounding_boxes.reshape(-1, 4).unbind()]
+    ).reshape(bounding_boxes.shape)


 def sample_inputs_convert_format_bounding_boxes():
torchvision/datapoints/_bounding_box.py

@@ -26,6 +26,12 @@ class BoundingBoxFormat(Enum):
 class BoundingBoxes(Datapoint):
     """[BETA] :class:`torch.Tensor` subclass for bounding boxes.

+    .. note::
+        There should be only one :class:`~torchvision.datapoints.BoundingBoxes`
+        instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
+        although one :class:`~torchvision.datapoints.BoundingBoxes` object can
+        contain multiple bounding boxes.
+
     Args:
         data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
         format (BoundingBoxFormat, str): Format of the bounding box.
@@ -43,6 +49,10 @@ class BoundingBoxes(Datapoint):
     @classmethod
     def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes:  # type: ignore[override]
+        if tensor.ndim == 1:
+            tensor = tensor.unsqueeze(0)
+        elif tensor.ndim != 2:
+            raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
         if isinstance(format, str):
             format = BoundingBoxFormat[format.upper()]
         bounding_boxes = tensor.as_subclass(cls)
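The new docstring note and the dimension check work together: a sample carries a single BoundingBoxes datapoint holding all of its boxes as a (num_boxes, 4) tensor, and _wrap normalizes a 1D box to that shape while rejecting anything higher-dimensional. A minimal sketch of the intended per-sample layout:

    import torch
    from torchvision import datapoints

    img = torch.rand(3, 64, 64)
    sample = {
        "img": img,
        # one BoundingBoxes instance per sample, holding however many boxes the sample has
        "bbox": datapoints.BoundingBoxes(
            [[0, 0, 10, 10], [2, 2, 7, 7]], format="XYXY", canvas_size=(64, 64)
        ),
    }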
torchvision/prototype/transforms/_geometry.py

@@ -7,7 +7,7 @@ from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
 from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size
+from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size


 class FixedSizeCrop(Transform):
@@ -61,7 +61,7 @@ class FixedSizeCrop(Transform):
         bounding_boxes: Optional[torch.Tensor]
         try:
-            bounding_boxes = query_bounding_boxes(flat_inputs)
+            bounding_boxes = get_bounding_boxes(flat_inputs)
         except ValueError:
             bounding_boxes = None
torchvision/transforms/v2/_geometry.py

@@ -23,7 +23,7 @@ from ._utils import (
     _setup_float_or_seq,
     _setup_size,
 )
-from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size
+from .utils import get_bounding_boxes, has_all, has_any, is_simple_tensor, query_size


 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -1137,7 +1137,7 @@ class RandomIoUCrop(Transform):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         orig_h, orig_w = query_size(flat_inputs)
-        bboxes = query_bounding_boxes(flat_inputs)
+        bboxes = get_bounding_boxes(flat_inputs)

         while True:
             # sample an option
torchvision/transforms/v2/_misc.py

@@ -10,7 +10,7 @@ from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform

 from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
-from .utils import has_any, is_simple_tensor, query_bounding_boxes
+from .utils import get_bounding_boxes, has_any, is_simple_tensor


 # TODO: do we want/need to expose this?
@@ -384,13 +384,7 @@ class SanitizeBoundingBoxes(Transform):
             )

         flat_inputs, spec = tree_flatten(inputs)
-        # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
-        # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
-        boxes = query_bounding_boxes(flat_inputs)
-
-        if boxes.ndim != 2:
-            raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
+        # TODO: this enforces one single BoundingBoxes entry.
+        boxes = get_bounding_boxes(flat_inputs)

         if labels is not None and boxes.shape[0] != labels.shape[0]:
             raise ValueError(
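SanitizeBoundingBoxes no longer needs its own shape check because BoundingBoxes is guaranteed to be (num_boxes, 4) by construction; it only has to verify that boxes and labels line up. A minimal usage sketch, assuming the default labels_getter picks up a "labels" key in the sample dict:

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2 as transforms

    boxes = datapoints.BoundingBoxes(
        [[0, 0, 10, 10], [2, 2, 7, 7]], format="XYXY", canvas_size=(20, 20)
    )
    sample = {"bbox": boxes, "labels": torch.tensor([1, 2])}
    out = transforms.SanitizeBoundingBoxes()(sample)  # removes degenerate boxes and their labels together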
torchvision/transforms/v2/utils.py

@@ -9,13 +9,12 @@ from torchvision._utils import sequence_to_str
 from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor


-def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
-    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)]
-    if not bounding_boxes:
-        raise TypeError("No bounding boxes were found in the sample")
-    elif len(bounding_boxes) > 1:
-        raise ValueError("Found multiple bounding boxes instances in the sample")
-    return bounding_boxes.pop()
+def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
+    # This assumes there is only one bbox per sample as per the general convention
+    try:
+        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
+    except StopIteration:
+        raise ValueError("No bounding boxes were found in the sample")


 def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
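This is the core of the simplification: instead of collecting every BoundingBoxes entry and erroring on duplicates, get_bounding_boxes trusts the one-instance-per-sample convention and simply returns the first match, raising ValueError rather than TypeError when none is present. A minimal usage sketch:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2.utils import get_bounding_boxes

    flat_inputs = [
        torch.rand(3, 32, 32),  # a plain image tensor is ignored by the lookup
        datapoints.BoundingBoxes([[0, 0, 10, 10]], format="XYXY", canvas_size=(32, 32)),
    ]
    boxes = get_bounding_boxes(flat_inputs)  # returns the single BoundingBoxes entry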