Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
0316ed10
Unverified
Commit
0316ed10
authored
Feb 13, 2023
by
Philip Meier
Committed by
GitHub
Feb 13, 2023
Browse files
make clamp_bounding_box a kernel / dispatcher hybrid (#7227)
parent
2489f370
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
89 additions
and
20 deletions
+89
-20
test/prototype_common_utils.py
test/prototype_common_utils.py
+2
-2
test/prototype_transforms_kernel_infos.py
test/prototype_transforms_kernel_infos.py
+23
-1
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+38
-6
torchvision/prototype/transforms/_meta.py
torchvision/prototype/transforms/_meta.py
+1
-6
torchvision/prototype/transforms/functional/_meta.py
torchvision/prototype/transforms/functional/_meta.py
+25
-5
No files found.
test/prototype_common_utils.py
View file @
0316ed10
...
@@ -640,14 +640,14 @@ class TestMark:
...
@@ -640,14 +640,14 @@ class TestMark:
self
.
condition
=
condition
or
(
lambda
args_kwargs
:
True
)
self
.
condition
=
condition
or
(
lambda
args_kwargs
:
True
)
def
mark_framework_limitation
(
test_id
,
reason
):
def
mark_framework_limitation
(
test_id
,
reason
,
condition
=
None
):
# The purpose of this function is to have a single entry point for skip marks that are only there, because the test
# The purpose of this function is to have a single entry point for skip marks that are only there, because the test
# framework cannot handle the kernel in general or a specific parameter combination.
# framework cannot handle the kernel in general or a specific parameter combination.
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
# still justified.
# still justified.
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
# we are wasting CI resources for no reason for most of the time
# we are wasting CI resources for no reason for most of the time
return
TestMark
(
test_id
,
pytest
.
mark
.
skip
(
reason
=
reason
))
return
TestMark
(
test_id
,
pytest
.
mark
.
skip
(
reason
=
reason
)
,
condition
=
condition
)
class
InfoBase
:
class
InfoBase
:
...
...
test/prototype_transforms_kernel_infos.py
View file @
0316ed10
...
@@ -12,6 +12,7 @@ import torchvision.prototype.transforms.functional as F
...
@@ -12,6 +12,7 @@ import torchvision.prototype.transforms.functional as F
from
datasets_utils
import
combinations_grid
from
datasets_utils
import
combinations_grid
from
prototype_common_utils
import
(
from
prototype_common_utils
import
(
ArgsKwargs
,
ArgsKwargs
,
BoundingBoxLoader
,
get_num_channels
,
get_num_channels
,
ImageLoader
,
ImageLoader
,
InfoBase
,
InfoBase
,
...
@@ -25,6 +26,7 @@ from prototype_common_utils import (
...
@@ -25,6 +26,7 @@ from prototype_common_utils import (
make_video_loader
,
make_video_loader
,
make_video_loaders
,
make_video_loaders
,
mark_framework_limitation
,
mark_framework_limitation
,
TensorLoader
,
TestMark
,
TestMark
,
)
)
from
torch.utils._pytree
import
tree_map
from
torch.utils._pytree
import
tree_map
...
@@ -2010,8 +2012,15 @@ KERNEL_INFOS.extend(
...
@@ -2010,8 +2012,15 @@ KERNEL_INFOS.extend(
def
sample_inputs_clamp_bounding_box
():
def
sample_inputs_clamp_bounding_box
():
for
bounding_box_loader
in
make_bounding_box_loaders
():
for
bounding_box_loader
in
make_bounding_box_loaders
():
yield
ArgsKwargs
(
bounding_box_loader
)
simple_tensor_loader
=
TensorLoader
(
fn
=
lambda
shape
,
dtype
,
device
:
bounding_box_loader
.
fn
(
shape
,
dtype
,
device
).
as_subclass
(
torch
.
Tensor
),
shape
=
bounding_box_loader
.
shape
,
dtype
=
bounding_box_loader
.
dtype
,
)
yield
ArgsKwargs
(
yield
ArgsKwargs
(
bounding_box
_loader
,
format
=
bounding_box_loader
.
format
,
spatial_size
=
bounding_box_loader
.
spatial_size
simple_tensor
_loader
,
format
=
bounding_box_loader
.
format
,
spatial_size
=
bounding_box_loader
.
spatial_size
)
)
...
@@ -2020,6 +2029,19 @@ KERNEL_INFOS.append(
...
@@ -2020,6 +2029,19 @@ KERNEL_INFOS.append(
F
.
clamp_bounding_box
,
F
.
clamp_bounding_box
,
sample_inputs_fn
=
sample_inputs_clamp_bounding_box
,
sample_inputs_fn
=
sample_inputs_clamp_bounding_box
,
logs_usage
=
True
,
logs_usage
=
True
,
test_marks
=
[
mark_framework_limitation
(
(
"TestKernels"
,
"test_scripted_vs_eager"
),
reason
=
(
"The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
"`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
"`spatial_size` was passed"
),
condition
=
lambda
arg_kwargs
:
isinstance
(
arg_kwargs
.
args
[
0
],
BoundingBoxLoader
)
and
arg_kwargs
.
kwargs
.
get
(
"format"
)
is
None
and
arg_kwargs
.
kwargs
.
get
(
"spatial_size"
)
is
None
,
)
],
)
)
)
)
...
...
test/test_prototype_transforms_functional.py
View file @
0316ed10
...
@@ -155,12 +155,14 @@ class TestKernels:
...
@@ -155,12 +155,14 @@ class TestKernels:
if
batched_tensor
.
ndim
==
data_dims
:
if
batched_tensor
.
ndim
==
data_dims
:
return
batch
return
batch
return
[
unbatcheds
=
[]
self
.
_unbatch
(
unbatched
,
data_dims
=
data_dims
)
for
unbatched
in
(
for
unbatched
in
(
batched_tensor
.
unbind
(
0
)
if
not
metadata
else
[(
t
,
*
metadata
)
for
t
in
batched_tensor
.
unbind
(
0
)]
batched_tensor
.
unbind
(
0
)
if
not
metadata
else
[(
t
,
*
metadata
)
for
t
in
batched_tensor
.
unbind
(
0
)]
):
)
if
isinstance
(
batch
,
datapoints
.
_datapoint
.
Datapoint
):
]
unbatched
=
type
(
batch
).
wrap_like
(
batch
,
unbatched
)
unbatcheds
.
append
(
self
.
_unbatch
(
unbatched
,
data_dims
=
data_dims
))
return
unbatcheds
@
sample_inputs
@
sample_inputs
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
...
@@ -558,6 +560,36 @@ def test_normalize_image_tensor_stats(device, num_channels):
...
@@ -558,6 +560,36 @@ def test_normalize_image_tensor_stats(device, num_channels):
assert_samples_from_standard_normal
(
F
.
normalize_image_tensor
(
image
,
mean
,
std
))
assert_samples_from_standard_normal
(
F
.
normalize_image_tensor
(
image
,
mean
,
std
))
class
TestClampBoundingBox
:
@
pytest
.
mark
.
parametrize
(
"metadata"
,
[
dict
(),
dict
(
format
=
datapoints
.
BoundingBoxFormat
.
XYXY
),
dict
(
spatial_size
=
(
1
,
1
)),
],
)
def
test_simple_tensor_insufficient_metadata
(
self
,
metadata
):
simple_tensor
=
next
(
make_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
)
with
pytest
.
raises
(
ValueError
,
match
=
"simple tensor"
):
F
.
clamp_bounding_box
(
simple_tensor
,
**
metadata
)
@
pytest
.
mark
.
parametrize
(
"metadata"
,
[
dict
(
format
=
datapoints
.
BoundingBoxFormat
.
XYXY
),
dict
(
spatial_size
=
(
1
,
1
)),
dict
(
format
=
datapoints
.
BoundingBoxFormat
.
XYXY
,
spatial_size
=
(
1
,
1
)),
],
)
def
test_datapoint_explicit_metadata
(
self
,
metadata
):
datapoint
=
next
(
make_bounding_boxes
())
with
pytest
.
raises
(
ValueError
,
match
=
"bounding box datapoint"
):
F
.
clamp_bounding_box
(
datapoint
,
**
metadata
)
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py`
# `prototype_transforms_kernel_infos.py`
...
...
torchvision/prototype/transforms/_meta.py
View file @
0316ed10
...
@@ -51,9 +51,4 @@ class ClampBoundingBoxes(Transform):
...
@@ -51,9 +51,4 @@ class ClampBoundingBoxes(Transform):
_transformed_types
=
(
datapoints
.
BoundingBox
,)
_transformed_types
=
(
datapoints
.
BoundingBox
,)
def
_transform
(
self
,
inpt
:
datapoints
.
BoundingBox
,
params
:
Dict
[
str
,
Any
])
->
datapoints
.
BoundingBox
:
def
_transform
(
self
,
inpt
:
datapoints
.
BoundingBox
,
params
:
Dict
[
str
,
Any
])
->
datapoints
.
BoundingBox
:
# We need to unwrap here to avoid unnecessary `__torch_function__` calls,
return
F
.
clamp_bounding_box
(
inpt
)
# type: ignore[return-value]
# since `clamp_bounding_box` does not have a dispatcher function that would do that for us
output
=
F
.
clamp_bounding_box
(
inpt
.
as_subclass
(
torch
.
Tensor
),
format
=
inpt
.
format
,
spatial_size
=
inpt
.
spatial_size
)
return
datapoints
.
BoundingBox
.
wrap_like
(
inpt
,
output
)
torchvision/prototype/transforms/functional/_meta.py
View file @
0316ed10
from
typing
import
List
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
PIL.Image
import
PIL.Image
import
torch
import
torch
...
@@ -209,12 +209,9 @@ def convert_format_bounding_box(
...
@@ -209,12 +209,9 @@ def convert_format_bounding_box(
return
bounding_box
return
bounding_box
def
clamp_bounding_box
(
def
_
clamp_bounding_box
(
bounding_box
:
torch
.
Tensor
,
format
:
BoundingBoxFormat
,
spatial_size
:
Tuple
[
int
,
int
]
bounding_box
:
torch
.
Tensor
,
format
:
BoundingBoxFormat
,
spatial_size
:
Tuple
[
int
,
int
]
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
clamp_bounding_box
)
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
# BoundingBoxFormat instead of converting back and forth
xyxy_boxes
=
convert_format_bounding_box
(
xyxy_boxes
=
convert_format_bounding_box
(
...
@@ -225,6 +222,29 @@ def clamp_bounding_box(
...
@@ -225,6 +222,29 @@ def clamp_bounding_box(
return
convert_format_bounding_box
(
xyxy_boxes
,
old_format
=
BoundingBoxFormat
.
XYXY
,
new_format
=
format
,
inplace
=
True
)
return
convert_format_bounding_box
(
xyxy_boxes
,
old_format
=
BoundingBoxFormat
.
XYXY
,
new_format
=
format
,
inplace
=
True
)
def
clamp_bounding_box
(
inpt
:
datapoints
.
InputTypeJIT
,
format
:
Optional
[
BoundingBoxFormat
]
=
None
,
spatial_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
)
->
datapoints
.
InputTypeJIT
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
clamp_bounding_box
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
format
is
None
or
spatial_size
is
None
:
raise
ValueError
(
"For simple tensor inputs, `format` and `spatial_size` has to be passed."
)
return
_clamp_bounding_box
(
inpt
,
format
=
format
,
spatial_size
=
spatial_size
)
elif
isinstance
(
inpt
,
datapoints
.
BoundingBox
):
if
format
is
not
None
or
spatial_size
is
not
None
:
raise
ValueError
(
"For bounding box datapoint inputs, `format` and `spatial_size` must not be passed."
)
output
=
_clamp_bounding_box
(
inpt
,
format
=
inpt
.
format
,
spatial_size
=
inpt
.
spatial_size
)
return
datapoints
.
BoundingBox
.
wrap_like
(
inpt
,
output
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor or a bounding box datapoint, but got
{
type
(
inpt
)
}
instead."
)
def
_num_value_bits
(
dtype
:
torch
.
dtype
)
->
int
:
def
_num_value_bits
(
dtype
:
torch
.
dtype
)
->
int
:
if
dtype
==
torch
.
uint8
:
if
dtype
==
torch
.
uint8
:
return
8
return
8
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment