Commit 602e8ca1 (unverified), authored Feb 14, 2023 by Philip Meier, committed by GitHub on Feb 14, 2023

clamp bounding boxes in some geometry kernels (#7215)

Co-authored-by: vfdev-5 <vfdev.5@gmail.com>

Parent: 6af6bf45
Showing 8 changed files with 189 additions and 89 deletions.
test/prototype_transforms_kernel_infos.py                 +93 -33
test/test_prototype_transforms.py                          +1  -1
test/test_prototype_transforms_functional.py              +23  -7
torchvision/prototype/datapoints/_bounding_box.py          +8  -2
torchvision/prototype/transforms/__init__.py               +1  -1
torchvision/prototype/transforms/_meta.py                  +1  -1
torchvision/prototype/transforms/functional/_geometry.py  +55 -42
torchvision/prototype/transforms/functional/_meta.py       +7  -2
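The change is the same across all eight files: once a geometry kernel has transformed box coordinates, the result can land outside the image, and the kernels now clamp it back against spatial_size. As a minimal standalone sketch of the operation for XYXY boxes (plain torch, not the library helper; spatial_size is (height, width) throughout this commit):

import torch

def clamp_xyxy(boxes: torch.Tensor, spatial_size) -> torch.Tensor:
    # spatial_size is (height, width): x coordinates live in [0, width], y in [0, height]
    height, width = spatial_size
    boxes = boxes.clone()
    boxes[..., 0::2].clamp_(min=0, max=width)   # x1, x2
    boxes[..., 1::2].clamp_(min=0, max=height)  # y1, y2
    return boxes

boxes = torch.tensor([[-5.0, 10.0, 40.0, 70.0]])
print(clamp_xyxy(boxes, spatial_size=(64, 32)))  # tensor([[ 0., 10., 32., 64.]])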
test/prototype_transforms_kernel_infos.py

@@ -108,7 +108,7 @@ def float32_vs_uint8_pixel_difference(atol=1, mae=False):
     }


-def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6):
+def scripted_vs_eager_float64_tolerances(device, atol=1e-6, rtol=1e-6):
     return {
         (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
     }
@@ -211,10 +211,12 @@ def reference_horizontal_flip_bounding_box(bounding_box, *, format, spatial_size
             [-1, 0, spatial_size[1]],
             [0, 1, 0],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )

-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )

     return expected_bboxes
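The affine matrix in this reference encodes a horizontal flip in homogeneous coordinates: it maps (x, y) to (spatial_size[1] - x, y). A quick numeric check with made-up sizes:

import numpy as np

spatial_size = (32, 64)  # (height, width)
flip = np.array(
    [
        [-1, 0, spatial_size[1]],
        [0, 1, 0],
    ],
    dtype="float64",
)
point = np.array([10.0, 20.0, 1.0])  # homogeneous (x, y, 1)
print(flip @ point)  # [54. 20.] -> x is mirrored to width - x, y is untouched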
@@ -322,7 +324,7 @@ def reference_inputs_resize_image_tensor():

 def sample_inputs_resize_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
-            yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
+            yield ArgsKwargs(bounding_box_loader, spatial_size=bounding_box_loader.spatial_size, size=size)


 def sample_inputs_resize_mask():

@@ -344,19 +346,20 @@ def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=
             [new_width / old_width, 0, 0],
             [0, new_height / old_height, 0],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )

     expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=datapoints.BoundingBoxFormat.XYXY, affine_matrix=affine_matrix
+        bounding_box, format=bounding_box.format, spatial_size=(new_height, new_width), affine_matrix=affine_matrix
     )

     return expected_bboxes, (new_height, new_width)


 def reference_inputs_resize_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(extra_dims=((), (4,))):
+    for bounding_box_loader in make_bounding_box_loaders(
+        formats=[datapoints.BoundingBoxFormat.XYXY], extra_dims=((), (4,))
+    ):
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
             yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
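The resize reference uses the same homogeneous-matrix machinery with a pure scaling matrix: each corner coordinate is multiplied by new/old size per axis. A small check with arbitrary sizes:

import numpy as np

old_height, old_width = 32, 64
new_height, new_width = 16, 16
scale = np.array(
    [
        [new_width / old_width, 0, 0],
        [0, new_height / old_height, 0],
    ]
)
corner = np.array([40.0, 8.0, 1.0])  # homogeneous (x, y, 1)
print(scale @ corner)  # [10.  4.] -> x scaled by 16/64, y by 16/32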
@@ -543,14 +546,17 @@ def _compute_affine_matrix(angle, translate, scale, shear, center):
     return true_matrix


-def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix):
-    def transform(bbox, affine_matrix_, format_):
+def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
+    def transform(bbox, affine_matrix_, format_, spatial_size_):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
         if not torch.is_floating_point(bbox):
             bbox = bbox.float()
         bbox_xyxy = F.convert_format_bounding_box(
-            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+            bbox.as_subclass(torch.Tensor),
+            old_format=format_,
+            new_format=datapoints.BoundingBoxFormat.XYXY,
+            inplace=True,
         )
         points = np.array(
             [

@@ -573,12 +579,15 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
         out_bbox = F.convert_format_bounding_box(
             out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
         )
-        return out_bbox.to(dtype=in_dtype)
+        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
+        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
+        out_bbox = out_bbox.to(dtype=in_dtype)
+        return out_bbox

     if bounding_box.ndim < 2:
         bounding_box = [bounding_box]

-    expected_bboxes = [transform(bbox, affine_matrix, format) for bbox in bounding_box]
+    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
     if len(expected_bboxes) > 1:
         expected_bboxes = torch.stack(expected_bboxes)
     else:

@@ -594,7 +603,9 @@ def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle,
     affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
     affine_matrix = affine_matrix[:2, :]

-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )

     return expected_bboxes
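The new comment in the helper ("clamp before casting") is load-bearing: clamping is defined on XYXY corners, so for an integer CXCYWH box the clamp has to happen while the values are still float and still convertible, otherwise casting first bakes the out-of-image extent into the width. A hand-rolled illustration with explicit XYXY to CXCYWH arithmetic (not the library converters):

import torch

# transformed box in XYXY, float, partially left of a 32-pixel-wide image
xyxy = torch.tensor([[-10.0, 0.0, 25.0, 20.0]])

# clamp first, then convert to CXCYWH and cast to int64
clamped = xyxy.clamp(min=0)
cxcywh = torch.stack(
    [
        (clamped[:, 0] + clamped[:, 2]) / 2,  # cx = 12.5
        (clamped[:, 1] + clamped[:, 3]) / 2,  # cy = 10.0
        clamped[:, 2] - clamped[:, 0],        # w  = 25.0
        clamped[:, 3] - clamped[:, 1],        # h  = 20.0
    ],
    dim=1,
)
print(cxcywh.to(torch.int64))  # tensor([[12, 10, 25, 20]])

# converting and casting without the clamp keeps the out-of-image width
cx, w = (xyxy[:, 0] + xyxy[:, 2]) / 2, xyxy[:, 2] - xyxy[:, 0]
print(cx.to(torch.int64), w.to(torch.int64))  # tensor([7]) tensor([35])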
@@ -643,9 +654,6 @@ KERNEL_INFOS.extend(
             sample_inputs_fn=sample_inputs_affine_bounding_box,
             reference_fn=reference_affine_bounding_box,
             reference_inputs_fn=reference_inputs_affine_bounding_box,
-            closeness_kwargs={
-                (("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
-            },
             test_marks=[
                 xfail_jit_python_scalar_arg("shear"),
             ],

@@ -729,10 +737,12 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
             [1, 0, 0],
             [0, -1, spatial_size[0]],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )

-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )

     return expected_bboxes

@@ -806,6 +816,43 @@ def sample_inputs_rotate_bounding_box():
         )


+def reference_inputs_rotate_bounding_box():
+    for bounding_box_loader, angle in itertools.product(
+        make_bounding_box_loaders(extra_dims=((), (4,))), _ROTATE_ANGLES
+    ):
+        yield ArgsKwargs(
+            bounding_box_loader,
+            format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
+            angle=angle,
+        )
+
+
+# TODO: add samples with expand=True and center
+def reference_rotate_bounding_box(bounding_box, *, format, spatial_size, angle, expand=False, center=None):
+    if center is None:
+        center = [spatial_size[1] * 0.5, spatial_size[0] * 0.5]
+
+    a = np.cos(angle * np.pi / 180.0)
+    b = np.sin(angle * np.pi / 180.0)
+    cx = center[0]
+    cy = center[1]
+    affine_matrix = np.array(
+        [
+            [a, b, cx - cx * a - b * cy],
+            [-b, a, cy + cx * b - a * cy],
+        ],
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
+    )
+
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
+    return expected_bboxes, spatial_size
+
+
 def sample_inputs_rotate_mask():
     for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
         yield ArgsKwargs(mask_loader, angle=15.0)
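The added reference rotates about center with the classic 2x3 rotation matrix; its translation column is chosen so that the center itself is a fixed point. Verifying that property with arbitrary values:

import numpy as np

angle = 30.0
cx, cy = 16.0, 12.0
a = np.cos(angle * np.pi / 180.0)
b = np.sin(angle * np.pi / 180.0)
affine_matrix = np.array(
    [
        [a, b, cx - cx * a - b * cy],
        [-b, a, cy + cx * b - a * cy],
    ]
)
# rotating the center around itself must give back the center
print(affine_matrix @ np.array([cx, cy, 1.0]))  # [16. 12.]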
@@ -834,9 +881,11 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.rotate_bounding_box,
             sample_inputs_fn=sample_inputs_rotate_bounding_box,
+            reference_fn=reference_rotate_bounding_box,
+            reference_inputs_fn=reference_inputs_rotate_bounding_box,
             closeness_kwargs={
-                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
             },
         ),
         KernelInfo(

@@ -897,17 +946,19 @@ def sample_inputs_crop_video():

 def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
     affine_matrix = np.array(
         [
             [1, 0, -left],
             [0, 1, -top],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )

-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
-    return expected_bboxes, (height, width)
+    spatial_size = (height, width)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
+    return expected_bboxes, spatial_size


 def reference_inputs_crop_bounding_box():

@@ -1119,13 +1170,15 @@ def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, p
             [1, 0, left],
             [0, 1, top],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )

     height = spatial_size[0] + top + bottom
     width = spatial_size[1] + left + right

-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=(height, width), affine_matrix=affine_matrix
+    )
     return expected_bboxes, (height, width)

@@ -1225,14 +1278,16 @@ def sample_inputs_perspective_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
             startpoints=None,
             endpoints=None,
             coefficients=_PERSPECTIVE_COEFFS[0],
         )

     format = datapoints.BoundingBoxFormat.XYXY
+    loader = make_bounding_box_loader(format=format)
     yield ArgsKwargs(
-        make_bounding_box_loader(format=format), format=format, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
+        loader, format=format, spatial_size=loader.spatial_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
     )

@@ -1269,13 +1324,17 @@ KERNEL_INFOS.extend(
                 **pil_reference_pixel_difference(2, mae=True),
                 **cuda_vs_cpu_pixel_difference(),
                 **float32_vs_uint8_pixel_difference(),
-                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
             },
         ),
         KernelInfo(
             F.perspective_bounding_box,
             sample_inputs_fn=sample_inputs_perspective_bounding_box,
+            closeness_kwargs={
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6),
+            },
         ),
         KernelInfo(
             F.perspective_mask,

@@ -1292,8 +1351,8 @@ KERNEL_INFOS.extend(
             sample_inputs_fn=sample_inputs_perspective_video,
             closeness_kwargs={
                 **cuda_vs_cpu_pixel_difference(),
-                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
             },
         ),
     ]

@@ -1331,6 +1390,7 @@ def sample_inputs_elastic_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
             displacement=displacement,
         )
test/test_prototype_transforms.py

@@ -146,7 +146,7 @@ class TestSmoke:
             (transforms.RandomZoomOut(p=1.0), None),
             (transforms.Resize([16, 16], antialias=True), None),
             (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2)), None),
-            (transforms.ClampBoundingBoxes(), None),
+            (transforms.ClampBoundingBox(), None),
             (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
             (transforms.ConvertDtype(), None),
             (transforms.GaussianBlur(kernel_size=3), None),
test/test_prototype_transforms_functional.py

@@ -25,7 +25,7 @@ from torch.utils._pytree import tree_map
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F
 from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
-from torchvision.prototype.transforms.functional._meta import convert_format_bounding_box
+from torchvision.prototype.transforms.functional._meta import clamp_bounding_box, convert_format_bounding_box
 from torchvision.transforms.functional import _get_perspective_coeffs

@@ -257,16 +257,17 @@ class TestKernels:
     @reference_inputs
     def test_against_reference(self, test_id, info, args_kwargs):
         (input, *other_args), kwargs = args_kwargs.load("cpu")
-        input = input.as_subclass(torch.Tensor)

-        actual = info.kernel(input, *other_args, **kwargs)
+        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
+        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
+        # metadata regardless of whether the kernel takes it explicitly or not
         expected = info.reference_fn(input, *other_args, **kwargs)

         assert_close(
             actual,
             expected,
             **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
-            msg=parametrized_error_message(*other_args, **kwargs),
+            msg=parametrized_error_message(input, *other_args, **kwargs),
         )

     @make_info_args_kwargs_parametrization(
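The asymmetry in the updated test is the point: low-level kernels take a plain tensor plus explicit metadata, while a reference function may read metadata straight off the datapoint, so only the kernel input gets unwrapped. Tensor.as_subclass does the unwrapping without copying; a tiny sketch with a hypothetical stand-in subclass (not the torchvision datapoint classes):

import torch

class TaggedTensor(torch.Tensor):
    # hypothetical stand-in for a datapoint that carries metadata
    spatial_size = (32, 32)

wrapped = torch.rand(3, 4).as_subclass(TaggedTensor)
plain = wrapped.as_subclass(torch.Tensor)  # same storage, metadata-free view
print(type(wrapped).__name__, wrapped.spatial_size)       # TaggedTensor (32, 32)
print(type(plain).__name__, hasattr(plain, "spatial_size"))  # Tensor False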
@@ -682,6 +683,10 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
         (48.56528888843238, 9.611532109828834, 53.35347829361575, 14.39972151501221),
     ]

+    expected_bboxes = clamp_bounding_box(
+        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
+    ).tolist()
+
     output_boxes = F.affine_bounding_box(
         in_boxes,
         format=format,

@@ -762,7 +767,8 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(out_bbox, new_format=bbox.format), (height, width)
+        out_bbox = clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format))
+        return out_bbox, (height, width)

     spatial_size = (32, 38)

@@ -839,6 +845,9 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
         [69.27564928, 12.39339828, 74.93250353, 18.05025253],
         [18.36396103, 1.07968978, 46.64823228, 29.36396103],
     ]
+    expected_bboxes = clamp_bounding_box(
+        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
+    ).tolist()

     output_boxes, _ = F.rotate_bounding_box(
         in_boxes,

@@ -905,6 +914,10 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
     if format != datapoints.BoundingBoxFormat.XYXY:
         in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

+    expected_bboxes = clamp_bounding_box(
+        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
+    ).tolist()
+
     output_boxes, output_spatial_size = F.crop_bounding_box(
         in_boxes,
         format,

@@ -1121,7 +1134,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(out_bbox, new_format=bbox.format)
+        return clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format))

     spatial_size = (32, 38)

@@ -1134,6 +1147,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
     output_bboxes = F.perspective_bounding_box(
         bboxes.as_subclass(torch.Tensor),
         format=bboxes.format,
+        spatial_size=bboxes.spatial_size,
         startpoints=None,
         endpoints=None,
         coefficients=pcoeffs,

@@ -1178,6 +1192,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
         ]
         out_bbox = torch.tensor(out_bbox)
         out_bbox = convert_format_bounding_box(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
+        out_bbox = clamp_bounding_box(out_bbox, format=format_, spatial_size=output_size)
         return out_bbox.to(dtype=dtype, device=bbox.device)

     for bboxes in make_bounding_boxes(extra_dims=((4,),)):

@@ -1201,7 +1216,8 @@ def test_correctness_center_crop_bounding_box(device, output_size):
         expected_bboxes = torch.stack(expected_bboxes)
     else:
         expected_bboxes = expected_bboxes[0]
-    torch.testing.assert_close(output_boxes, expected_bboxes)
+
+    torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
     torch.testing.assert_close(output_spatial_size, output_size)
torchvision/prototype/datapoints/_bounding_box.py

@@ -81,7 +81,10 @@ class BoundingBox(Datapoint):
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> BoundingBox:
         output, spatial_size = self._F.resize_bounding_box(
-            self.as_subclass(torch.Tensor), spatial_size=self.spatial_size, size=size, max_size=max_size
+            self.as_subclass(torch.Tensor),
+            spatial_size=self.spatial_size,
+            size=size,
+            max_size=max_size,
         )
         return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)

@@ -178,6 +181,7 @@ class BoundingBox(Datapoint):
         output = self._F.perspective_bounding_box(
             self.as_subclass(torch.Tensor),
             format=self.format,
+            spatial_size=self.spatial_size,
             startpoints=startpoints,
             endpoints=endpoints,
             coefficients=coefficients,

@@ -190,5 +194,7 @@ class BoundingBox(Datapoint):
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         fill: FillTypeJIT = None,
     ) -> BoundingBox:
-        output = self._F.elastic_bounding_box(self.as_subclass(torch.Tensor), self.format, displacement)
+        output = self._F.elastic_bounding_box(
+            self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement
+        )
         return BoundingBox.wrap_like(self, output)
torchvision/prototype/transforms/__init__.py

@@ -41,7 +41,7 @@ from ._geometry import (
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
+from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
 from ._misc import (
     GaussianBlur,
     Identity,
torchvision/prototype/transforms/_meta.py

@@ -42,7 +42,7 @@ class ConvertDtype(Transform):
 ConvertImageDtype = ConvertDtype


-class ClampBoundingBoxes(Transform):
+class ClampBoundingBox(Transform):
     _transformed_types = (datapoints.BoundingBox,)

     def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
torchvision/prototype/transforms/functional/_geometry.py

@@ -22,7 +22,7 @@ from torchvision.transforms.functional_tensor import _pad_symmetric
 from torchvision.utils import _log_api_usage_once

-from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
+from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil
 from ._utils import is_simple_tensor

@@ -580,8 +580,9 @@ def affine_image_pil(
     return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)


-def _affine_bounding_box_xyxy(
+def _affine_bounding_box_with_expand(
     bounding_box: torch.Tensor,
+    format: datapoints.BoundingBoxFormat,
     spatial_size: Tuple[int, int],
     angle: Union[int, float],
     translate: List[float],

@@ -593,6 +594,17 @@ def _affine_bounding_box_xyxy(
     if bounding_box.numel() == 0:
         return bounding_box, spatial_size

+    original_shape = bounding_box.shape
+    original_dtype = bounding_box.dtype
+    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
+    dtype = bounding_box.dtype
+    device = bounding_box.device
+    bounding_box = (
+        convert_format_bounding_box(
+            bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+        )
+    ).reshape(-1, 4)
+
     angle, translate, shear, center = _affine_parse_args(
         angle, translate, scale, shear, InterpolationMode.NEAREST, center
     )

@@ -601,9 +613,6 @@ def _affine_bounding_box_xyxy(
         height, width = spatial_size
         center = [width * 0.5, height * 0.5]

-    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
-    device = bounding_box.device
-
     affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
     transposed_affine_matrix = (
         torch.tensor(

@@ -651,7 +660,13 @@ def _affine_bounding_box_xyxy(
         new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
         spatial_size = (new_height, new_width)

-    return out_bboxes.to(bounding_box.dtype), spatial_size
+    out_bboxes = clamp_bounding_box(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
+    out_bboxes = convert_format_bounding_box(
+        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+    ).reshape(original_shape)
+
+    out_bboxes = out_bboxes.to(original_dtype)
+    return out_bboxes, spatial_size


 def affine_bounding_box(
@@ -664,19 +679,18 @@ def affine_bounding_box(
     shear: List[float],
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    original_shape = bounding_box.shape
-
-    bounding_box = (
-        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
-    ).reshape(-1, 4)
-
-    out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center)
-
-    # out_bboxes should be of shape [N boxes, 4]
-
-    return convert_format_bounding_box(
-        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
-    ).reshape(original_shape)
+    out_box, _ = _affine_bounding_box_with_expand(
+        bounding_box,
+        format=format,
+        spatial_size=spatial_size,
+        angle=angle,
+        translate=translate,
+        scale=scale,
+        shear=shear,
+        center=center,
+        expand=False,
+    )
+    return out_box


 def affine_mask(

@@ -852,14 +866,10 @@ def rotate_bounding_box(
         warnings.warn("The provided center argument has no effect on the result if expand is True")
         center = None

-    original_shape = bounding_box.shape
-    bounding_box = (
-        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
-    ).reshape(-1, 4)
-
-    out_bboxes, spatial_size = _affine_bounding_box_xyxy(
+    return _affine_bounding_box_with_expand(
         bounding_box,
-        spatial_size,
+        format=format,
+        spatial_size=spatial_size,
         angle=-angle,
         translate=[0.0, 0.0],
         scale=1.0,

@@ -868,13 +878,6 @@ def rotate_bounding_box(
         expand=expand,
     )

-    return (
-        convert_format_bounding_box(
-            out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
-        ).reshape(original_shape),
-        spatial_size,
-    )
-

 def rotate_mask(
     mask: torch.Tensor,
@@ -1112,8 +1115,9 @@ def pad_bounding_box(
     height, width = spatial_size
     height += top + bottom
     width += left + right
+    spatial_size = (height, width)

-    return bounding_box, (height, width)
+    return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size


 def pad_video(

@@ -1185,8 +1189,9 @@ def crop_bounding_box(
     sub = [left, top, 0, 0]
     bounding_box = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype, device=bounding_box.device)
+    spatial_size = (height, width)

-    return bounding_box, (height, width)
+    return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size


 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
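Both kernels now finish by clamping against the canvas they just computed. For crop, the visible shift subtracts (left, top) from the first two coordinates, moving the box into the crop's coordinate frame; a box that straddles the crop border ends up with negative coordinates, which is exactly what the new clamp_bounding_box call cleans up. A sketch of the shift alone, assuming a format whose first two entries are the top-left corner (format conversion elided, values made up):

import torch

top, left = 8, 12
sub = [left, top, 0, 0]
bounding_box = torch.tensor([2.0, 10.0, 5.0, 5.0])  # first two entries are the top-left corner
shifted = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype)
print(shifted)  # tensor([-10.,   2.,   5.,   5.]) -> x is now negative and will be clamped to 0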
@@ -1332,6 +1337,7 @@ def perspective_image_pil(
 def perspective_bounding_box(
     bounding_box: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
+    spatial_size: Tuple[int, int],
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     coefficients: Optional[List[float]] = None,

@@ -1342,6 +1348,7 @@ def perspective_bounding_box(
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
     original_shape = bounding_box.shape
+    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
     bounding_box = (
         convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

@@ -1408,7 +1415,11 @@ def perspective_bounding_box(
     transformed_points = transformed_points.reshape(-1, 4, 2)
     out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
-    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
+    out_bboxes = clamp_bounding_box(
+        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=spatial_size,
+    )

     # out_bboxes should be of shape [N boxes, 4]
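The perspective kernel warps the four corners of each box independently and then takes per-box min/max to get an axis-aligned result; strongly warped corners can land outside the image, which is why the aminmax result is now routed through clamp_bounding_box. The enclosing-box step in isolation, with fabricated corner coordinates on a square 32x32 image (square so a single clamp call covers both axes in this sketch):

import torch

# four transformed corner points of one box, some outside the image
transformed_points = torch.tensor([[[-4.0, 2.0], [30.0, 2.0], [35.0, 28.0], [-1.0, 30.0]]])
mins, maxs = torch.aminmax(transformed_points, dim=1)
out_bbox = torch.cat([mins, maxs], dim=1)
print(out_bbox)                       # tensor([[-4.,  2., 35., 30.]])
print(out_bbox.clamp(min=0, max=32))  # tensor([[ 0.,  2., 32., 30.]])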
@@ -1549,6 +1560,7 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: to
 def elastic_bounding_box(
     bounding_box: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
+    spatial_size: Tuple[int, int],
     displacement: torch.Tensor,
 ) -> torch.Tensor:
     if bounding_box.numel() == 0:

@@ -1562,14 +1574,11 @@ def elastic_bounding_box(
     displacement = displacement.to(dtype=dtype, device=device)

     original_shape = bounding_box.shape
+    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
     bounding_box = (
         convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

-    # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
-    # Or add spatial_size arg and check displacement shape
-    spatial_size = displacement.shape[-3], displacement.shape[-2]
-
     id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
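The deleted question is answered by the signature change: instead of trusting the displacement field's shape, the caller now passes spatial_size explicitly, matching the other bounding-box kernels. For reference, this is what the removed inference did, assuming the elastic displacement is laid out as (1, height, width, 2):

import torch

displacement = torch.zeros(1, 24, 32, 2)  # fabricated field for a 24x32 image
spatial_size = displacement.shape[-3], displacement.shape[-2]
print(spatial_size)  # (24, 32)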
@@ -1588,7 +1597,11 @@ def elastic_bounding_box(
     transformed_points = transformed_points.reshape(-1, 4, 2)
     out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
-    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
+    out_bboxes = clamp_bounding_box(
+        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=spatial_size,
+    )

     return convert_format_bounding_box(
         out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True

@@ -1796,7 +1809,7 @@ def resized_crop_bounding_box(
     size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     bounding_box, _ = crop_bounding_box(bounding_box, format, top, left, height, width)
-    return resize_bounding_box(bounding_box, (height, width), size)
+    return resize_bounding_box(bounding_box, spatial_size=(height, width), size=size)


 def resized_crop_mask(
torchvision/prototype/transforms/functional/_meta.py

@@ -245,12 +245,17 @@ def _clamp_bounding_box(
 ) -> torch.Tensor:
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     #  BoundingBoxFormat instead of converting back and forth
+    in_dtype = bounding_box.dtype
+    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
     xyxy_boxes = convert_format_bounding_box(
-        bounding_box.clone(), old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+        bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
     )
     xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
-    return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)
+    out_boxes = convert_format_bounding_box(
+        xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
+    )
+    return out_boxes.to(in_dtype)


 def clamp_bounding_box(
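Putting the _clamp_bounding_box changes together: remember the incoming dtype, work in float (the clone-or-float line also protects the caller's tensor, since the conversion runs inplace), clamp x against width and y against height in XYXY space, convert back, and restore the dtype at the very end. A condensed sketch of that flow for a box that is already XYXY, so the format conversions drop out:

import torch

spatial_size = (24, 32)  # (height, width)
bounding_box = torch.tensor([[-5, 3, 40, 20]], dtype=torch.int64)

in_dtype = bounding_box.dtype
boxes = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])  # x1, x2 against width
boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])  # y1, y2 against height
print(boxes.to(in_dtype))  # tensor([[ 0,  3, 32, 20]])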