Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
ea37cd38
Unverified
Commit
ea37cd38
authored
Feb 13, 2023
by
Philip Meier
Committed by
GitHub
Feb 13, 2023
Browse files
make convert_format_bounding_box a hybrid kernel dispatcher (#7228)
parent
0316ed10
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
94 additions
and
21 deletions
+94
-21
test/prototype_common_utils.py
test/prototype_common_utils.py
+7
-0
test/prototype_transforms_kernel_infos.py
test/prototype_transforms_kernel_infos.py
+25
-10
test/test_prototype_transforms_functional.py
test/test_prototype_transforms_functional.py
+29
-2
torchvision/prototype/transforms/_meta.py
torchvision/prototype/transforms/_meta.py
+1
-6
torchvision/prototype/transforms/functional/_meta.py
torchvision/prototype/transforms/functional/_meta.py
+32
-3
No files found.
test/prototype_common_utils.py
View file @
ea37cd38
...
@@ -237,6 +237,13 @@ class TensorLoader:
...
@@ -237,6 +237,13 @@ class TensorLoader:
def
load
(
self
,
device
):
def
load
(
self
,
device
):
return
self
.
fn
(
self
.
shape
,
self
.
dtype
,
device
)
return
self
.
fn
(
self
.
shape
,
self
.
dtype
,
device
)
def
unwrap
(
self
):
return
TensorLoader
(
fn
=
lambda
shape
,
dtype
,
device
:
self
.
fn
(
shape
,
dtype
,
device
).
as_subclass
(
torch
.
Tensor
),
shape
=
self
.
shape
,
dtype
=
self
.
dtype
,
)
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
ImageLoader
(
TensorLoader
):
class
ImageLoader
(
TensorLoader
):
...
...
test/prototype_transforms_kernel_infos.py
View file @
ea37cd38
...
@@ -26,7 +26,6 @@ from prototype_common_utils import (
...
@@ -26,7 +26,6 @@ from prototype_common_utils import (
make_video_loader
,
make_video_loader
,
make_video_loaders
,
make_video_loaders
,
mark_framework_limitation
,
mark_framework_limitation
,
TensorLoader
,
TestMark
,
TestMark
,
)
)
from
torch.utils._pytree
import
tree_map
from
torch.utils._pytree
import
tree_map
...
@@ -660,7 +659,8 @@ KERNEL_INFOS.extend(
...
@@ -660,7 +659,8 @@ KERNEL_INFOS.extend(
def
sample_inputs_convert_format_bounding_box
():
def
sample_inputs_convert_format_bounding_box
():
formats
=
list
(
datapoints
.
BoundingBoxFormat
)
formats
=
list
(
datapoints
.
BoundingBoxFormat
)
for
bounding_box_loader
,
new_format
in
itertools
.
product
(
make_bounding_box_loaders
(
formats
=
formats
),
formats
):
for
bounding_box_loader
,
new_format
in
itertools
.
product
(
make_bounding_box_loaders
(
formats
=
formats
),
formats
):
yield
ArgsKwargs
(
bounding_box_loader
,
old_format
=
bounding_box_loader
.
format
,
new_format
=
new_format
)
yield
ArgsKwargs
(
bounding_box_loader
,
new_format
=
new_format
)
yield
ArgsKwargs
(
bounding_box_loader
.
unwrap
(),
old_format
=
bounding_box_loader
.
format
,
new_format
=
new_format
)
def
reference_convert_format_bounding_box
(
bounding_box
,
old_format
,
new_format
):
def
reference_convert_format_bounding_box
(
bounding_box
,
old_format
,
new_format
):
...
@@ -671,8 +671,14 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
...
@@ -671,8 +671,14 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
def
reference_inputs_convert_format_bounding_box
():
def
reference_inputs_convert_format_bounding_box
():
for
args_kwargs
in
sample_inputs_convert_format_bounding_box
():
for
args_kwargs
in
sample_inputs_convert_format_bounding_box
():
if
len
(
args_kwargs
.
args
[
0
].
shape
)
==
2
:
if
len
(
args_kwargs
.
args
[
0
].
shape
)
!=
2
:
yield
args_kwargs
continue
(
loader
,
*
other_args
),
kwargs
=
args_kwargs
if
isinstance
(
loader
,
BoundingBoxLoader
):
kwargs
[
"old_format"
]
=
loader
.
format
loader
=
loader
.
unwrap
()
yield
ArgsKwargs
(
loader
,
*
other_args
,
**
kwargs
)
KERNEL_INFOS
.
append
(
KERNEL_INFOS
.
append
(
...
@@ -682,6 +688,18 @@ KERNEL_INFOS.append(
...
@@ -682,6 +688,18 @@ KERNEL_INFOS.append(
reference_fn
=
reference_convert_format_bounding_box
,
reference_fn
=
reference_convert_format_bounding_box
,
reference_inputs_fn
=
reference_inputs_convert_format_bounding_box
,
reference_inputs_fn
=
reference_inputs_convert_format_bounding_box
,
logs_usage
=
True
,
logs_usage
=
True
,
test_marks
=
[
mark_framework_limitation
(
(
"TestKernels"
,
"test_scripted_vs_eager"
),
reason
=
(
"The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
"`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
"`spatial_size` was passed"
),
condition
=
lambda
arg_kwargs
:
isinstance
(
arg_kwargs
.
args
[
0
],
BoundingBoxLoader
)
and
arg_kwargs
.
kwargs
.
get
(
"old_format"
)
is
None
,
)
],
),
),
)
)
...
@@ -2014,13 +2032,10 @@ def sample_inputs_clamp_bounding_box():
...
@@ -2014,13 +2032,10 @@ def sample_inputs_clamp_bounding_box():
for
bounding_box_loader
in
make_bounding_box_loaders
():
for
bounding_box_loader
in
make_bounding_box_loaders
():
yield
ArgsKwargs
(
bounding_box_loader
)
yield
ArgsKwargs
(
bounding_box_loader
)
simple_tensor_loader
=
TensorLoader
(
fn
=
lambda
shape
,
dtype
,
device
:
bounding_box_loader
.
fn
(
shape
,
dtype
,
device
).
as_subclass
(
torch
.
Tensor
),
shape
=
bounding_box_loader
.
shape
,
dtype
=
bounding_box_loader
.
dtype
,
)
yield
ArgsKwargs
(
yield
ArgsKwargs
(
simple_tensor_loader
,
format
=
bounding_box_loader
.
format
,
spatial_size
=
bounding_box_loader
.
spatial_size
bounding_box_loader
.
unwrap
(),
format
=
bounding_box_loader
.
format
,
spatial_size
=
bounding_box_loader
.
spatial_size
,
)
)
...
...
test/test_prototype_transforms_functional.py
View file @
ea37cd38
...
@@ -572,7 +572,7 @@ class TestClampBoundingBox:
...
@@ -572,7 +572,7 @@ class TestClampBoundingBox:
def
test_simple_tensor_insufficient_metadata
(
self
,
metadata
):
def
test_simple_tensor_insufficient_metadata
(
self
,
metadata
):
simple_tensor
=
next
(
make_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
)
simple_tensor
=
next
(
make_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
)
with
pytest
.
raises
(
ValueError
,
match
=
"simple tensor"
):
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`format` and `spatial_size` has to be passed"
)
):
F
.
clamp_bounding_box
(
simple_tensor
,
**
metadata
)
F
.
clamp_bounding_box
(
simple_tensor
,
**
metadata
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -586,10 +586,37 @@ class TestClampBoundingBox:
...
@@ -586,10 +586,37 @@ class TestClampBoundingBox:
def
test_datapoint_explicit_metadata
(
self
,
metadata
):
def
test_datapoint_explicit_metadata
(
self
,
metadata
):
datapoint
=
next
(
make_bounding_boxes
())
datapoint
=
next
(
make_bounding_boxes
())
with
pytest
.
raises
(
ValueError
,
match
=
"bounding box datapoint"
):
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`format` and `spatial_size` must not be passed"
)
):
F
.
clamp_bounding_box
(
datapoint
,
**
metadata
)
F
.
clamp_bounding_box
(
datapoint
,
**
metadata
)
class
TestConvertFormatBoundingBox
:
@
pytest
.
mark
.
parametrize
(
(
"inpt"
,
"old_format"
),
[
(
next
(
make_bounding_boxes
()),
None
),
(
next
(
make_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
),
datapoints
.
BoundingBoxFormat
.
XYXY
),
],
)
def
test_missing_new_format
(
self
,
inpt
,
old_format
):
with
pytest
.
raises
(
TypeError
,
match
=
re
.
escape
(
"missing 1 required argument: 'new_format'"
)):
F
.
convert_format_bounding_box
(
inpt
,
old_format
)
def
test_simple_tensor_insufficient_metadata
(
self
):
simple_tensor
=
next
(
make_bounding_boxes
()).
as_subclass
(
torch
.
Tensor
)
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`old_format` has to be passed"
)):
F
.
convert_format_bounding_box
(
simple_tensor
,
new_format
=
datapoints
.
BoundingBoxFormat
.
CXCYWH
)
def
test_datapoint_explicit_metadata
(
self
):
datapoint
=
next
(
make_bounding_boxes
())
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"`old_format` must not be passed"
)):
F
.
convert_format_bounding_box
(
datapoint
,
old_format
=
datapoint
.
format
,
new_format
=
datapoints
.
BoundingBoxFormat
.
CXCYWH
)
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py`
# `prototype_transforms_kernel_infos.py`
...
...
torchvision/prototype/transforms/_meta.py
View file @
ea37cd38
...
@@ -19,12 +19,7 @@ class ConvertBoundingBoxFormat(Transform):
...
@@ -19,12 +19,7 @@ class ConvertBoundingBoxFormat(Transform):
self
.
format
=
format
self
.
format
=
format
def
_transform
(
self
,
inpt
:
datapoints
.
BoundingBox
,
params
:
Dict
[
str
,
Any
])
->
datapoints
.
BoundingBox
:
def
_transform
(
self
,
inpt
:
datapoints
.
BoundingBox
,
params
:
Dict
[
str
,
Any
])
->
datapoints
.
BoundingBox
:
# We need to unwrap here to avoid unnecessary `__torch_function__` calls,
return
F
.
convert_format_bounding_box
(
inpt
,
new_format
=
self
.
format
)
# type: ignore[return-value]
# since `convert_format_bounding_box` does not have a dispatcher function that would do that for us
output
=
F
.
convert_format_bounding_box
(
inpt
.
as_subclass
(
torch
.
Tensor
),
old_format
=
inpt
.
format
,
new_format
=
params
[
"format"
]
)
return
datapoints
.
BoundingBox
.
wrap_like
(
inpt
,
output
,
format
=
params
[
"format"
])
class
ConvertDtype
(
Transform
):
class
ConvertDtype
(
Transform
):
...
...
torchvision/prototype/transforms/functional/_meta.py
View file @
ea37cd38
...
@@ -186,11 +186,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
...
@@ -186,11 +186,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
return
xyxy
return
xyxy
def
convert_format_bounding_box
(
def
_
convert_format_bounding_box
(
bounding_box
:
torch
.
Tensor
,
old_format
:
BoundingBoxFormat
,
new_format
:
BoundingBoxFormat
,
inplace
:
bool
=
False
bounding_box
:
torch
.
Tensor
,
old_format
:
BoundingBoxFormat
,
new_format
:
BoundingBoxFormat
,
inplace
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_format_bounding_box
)
if
new_format
==
old_format
:
if
new_format
==
old_format
:
return
bounding_box
return
bounding_box
...
@@ -209,6 +207,37 @@ def convert_format_bounding_box(
...
@@ -209,6 +207,37 @@ def convert_format_bounding_box(
return
bounding_box
return
bounding_box
def
convert_format_bounding_box
(
inpt
:
datapoints
.
InputTypeJIT
,
old_format
:
Optional
[
BoundingBoxFormat
]
=
None
,
new_format
:
Optional
[
BoundingBoxFormat
]
=
None
,
inplace
:
bool
=
False
,
)
->
datapoints
.
InputTypeJIT
:
# This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
# inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value.
if
new_format
is
None
:
raise
TypeError
(
"convert_format_bounding_box() missing 1 required argument: 'new_format'"
)
if
not
torch
.
jit
.
is_scripting
():
_log_api_usage_once
(
convert_format_bounding_box
)
if
torch
.
jit
.
is_scripting
()
or
is_simple_tensor
(
inpt
):
if
old_format
is
None
:
raise
ValueError
(
"For simple tensor inputs, `old_format` has to be passed."
)
return
_convert_format_bounding_box
(
inpt
,
old_format
=
old_format
,
new_format
=
new_format
,
inplace
=
inplace
)
elif
isinstance
(
inpt
,
datapoints
.
BoundingBox
):
if
old_format
is
not
None
:
raise
ValueError
(
"For bounding box datapoint inputs, `old_format` must not be passed."
)
output
=
_convert_format_bounding_box
(
inpt
,
old_format
=
inpt
.
format
,
new_format
=
new_format
,
inplace
=
inplace
)
return
datapoints
.
BoundingBox
.
wrap_like
(
inpt
,
output
)
else
:
raise
TypeError
(
f
"Input can either be a plain tensor or a bounding box datapoint, but got
{
type
(
inpt
)
}
instead."
)
def
_clamp_bounding_box
(
def
_clamp_bounding_box
(
bounding_box
:
torch
.
Tensor
,
format
:
BoundingBoxFormat
,
spatial_size
:
Tuple
[
int
,
int
]
bounding_box
:
torch
.
Tensor
,
format
:
BoundingBoxFormat
,
spatial_size
:
Tuple
[
int
,
int
]
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment