Commit 88591717 (Unverified)
Authored Aug 14, 2023 by Nicolas Hug, committed via GitHub on Aug 14, 2023

Allow users to choose whether to return Datapoint subclasses or pure Tensor (#7825)

parent 3065ad59

Showing 8 changed files with 259 additions and 38 deletions (+259, −38)
docs/source/datapoints.rst                         +1    −0
test/test_datapoints.py                            +144  −24
torchvision/datapoints/__init__.py                 +2    −0
torchvision/datapoints/_bounding_box.py            +32   −6
torchvision/datapoints/_datapoint.py               +24   −7
torchvision/datapoints/_torch_function_helpers.py  +53   −0
torchvision/transforms/v2/_misc.py                 +1    −1
torchvision/transforms/v2/functional/_utils.py     +2    −0
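In short: by default, operations on Datapoint subclasses (Image, BoundingBoxes, Mask, Video) keep returning pure torch.Tensors, and the new datapoints.set_return_type() switch, usable globally or as a context manager, opts back into the subclasses. A minimal usage sketch, based on the docstring added in _torch_function_helpers.py below (assumes a torchvision build containing this commit):

    import torch
    from torchvision import datapoints

    img = datapoints.Image(torch.rand(3, 5, 5))

    print(type(img + 2))  # torch.Tensor -- the default behavior is unchanged

    # Opt in for a limited scope ...
    with datapoints.set_return_type("datapoint"):
        print(type(img + 2))  # datapoints.Image

    # ... or globally for the whole program.
    datapoints.set_return_type("datapoint")
    print(type(img + 2))      # datapoints.Image
    datapoints.set_return_type("tensor")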
docs/source/datapoints.rst
@@ -18,3 +18,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
     BoundingBoxes
     Mask
     Datapoint
+    set_return_type
test/test_datapoints.py
@@ -6,6 +6,20 @@ from common_utils import assert_equal
 from PIL import Image
 from torchvision import datapoints
+from common_utils import (
+    make_bounding_box,
+    make_detection_mask,
+    make_image,
+    make_image_tensor,
+    make_segmentation_mask,
+    make_video,
+)
+
+
+@pytest.fixture(autouse=True)
+def preserve_default_wrapping_behaviour():
+    yield
+    datapoints.set_return_type("Tensor")
 
 
 @pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
@@ -80,72 +94,88 @@ def test_to_wrapping():
     assert image_to.dtype is torch.float64
 
 
-def test_to_datapoint_reference():
+@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+def test_to_datapoint_reference(return_type):
     tensor = torch.rand((3, 16, 16), dtype=torch.float64)
     image = datapoints.Image(tensor)
 
-    tensor_to = tensor.to(image)
+    with datapoints.set_return_type(return_type):
+        tensor_to = tensor.to(image)
 
-    assert type(tensor_to) is torch.Tensor
+    assert type(tensor_to) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
     assert tensor_to.dtype is torch.float64
 
 
-def test_clone_wrapping():
+@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+def test_clone_wrapping(return_type):
     image = datapoints.Image(torch.rand(3, 16, 16))
 
-    image_clone = image.clone()
+    with datapoints.set_return_type(return_type):
+        image_clone = image.clone()
 
     assert type(image_clone) is datapoints.Image
     assert image_clone.data_ptr() != image.data_ptr()
 
 
-def test_requires_grad__wrapping():
+@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+def test_requires_grad__wrapping(return_type):
     image = datapoints.Image(torch.rand(3, 16, 16))
 
     assert not image.requires_grad
 
-    image_requires_grad = image.requires_grad_(True)
+    with datapoints.set_return_type(return_type):
+        image_requires_grad = image.requires_grad_(True)
 
     assert type(image_requires_grad) is datapoints.Image
     assert image.requires_grad
     assert image_requires_grad.requires_grad
 
 
-def test_detach_wrapping():
+@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+def test_detach_wrapping(return_type):
     image = datapoints.Image(torch.rand(3, 16, 16), requires_grad=True)
 
-    image_detached = image.detach()
+    with datapoints.set_return_type(return_type):
+        image_detached = image.detach()
 
     assert type(image_detached) is datapoints.Image
 
 
-def test_no_wrapping_exceptions_with_metadata():
-    # Sanity checks for the ops in _NO_WRAPPING_EXCEPTIONS and datapoints with metadata
+@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+def test_force_subclass_with_metadata(return_type):
+    # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and datapoints with metadata
     format, canvas_size = "XYXY", (32, 32)
     bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
 
+    datapoints.set_return_type(return_type)
     bbox = bbox.clone()
-    assert bbox.format, bbox.canvas_size == (format, canvas_size)
+    if return_type == "datapoint":
+        assert bbox.format, bbox.canvas_size == (format, canvas_size)
 
     bbox = bbox.to(torch.float64)
-    assert bbox.format, bbox.canvas_size == (format, canvas_size)
+    if return_type == "datapoint":
+        assert bbox.format, bbox.canvas_size == (format, canvas_size)
 
     bbox = bbox.detach()
-    assert bbox.format, bbox.canvas_size == (format, canvas_size)
+    if return_type == "datapoint":
+        assert bbox.format, bbox.canvas_size == (format, canvas_size)
 
     assert not bbox.requires_grad
     bbox.requires_grad_(True)
-    assert bbox.format, bbox.canvas_size == (format, canvas_size)
-    assert bbox.requires_grad
+    if return_type == "datapoint":
+        assert bbox.format, bbox.canvas_size == (format, canvas_size)
+        assert bbox.requires_grad
 
 
-def test_other_op_no_wrapping():
+@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+def test_other_op_no_wrapping(return_type):
     image = datapoints.Image(torch.rand(3, 16, 16))
 
-    # any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here
-    output = image * 2
+    with datapoints.set_return_type(return_type):
+        # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
+        output = image * 2
 
-    assert type(output) is torch.Tensor
+    assert type(output) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
 
 
 @pytest.mark.parametrize(
@@ -164,19 +194,21 @@ def test_no_tensor_output_op_no_wrapping(op):
     assert type(output) is not datapoints.Image
 
 
-def test_inplace_op_no_wrapping():
+@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+def test_inplace_op_no_wrapping(return_type):
     image = datapoints.Image(torch.rand(3, 16, 16))
 
-    output = image.add_(0)
+    with datapoints.set_return_type(return_type):
+        output = image.add_(0)
 
-    assert type(output) is torch.Tensor
+    assert type(output) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
     assert type(image) is datapoints.Image
 
 
 def test_wrap_like():
     image = datapoints.Image(torch.rand(3, 16, 16))
 
-    # any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here
+    # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
     output = image * 2
 
     image_new = datapoints.Image.wrap_like(image, output)
@@ -209,3 +241,91 @@ def test_deepcopy(datapoint, requires_grad):
     assert type(datapoint_deepcopied) is type(datapoint)
     assert datapoint_deepcopied.requires_grad is requires_grad
+
+
+@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
+def test_operations(return_type):
+    datapoints.set_return_type(return_type)
+    img = datapoints.Image(torch.rand(3, 10, 10))
+    t = torch.rand(3, 10, 10)
+    mask = datapoints.Mask(torch.rand(1, 10, 10))
+
+    for out in (
+        [
+            img + t,
+            t + img,
+            img * t,
+            t * img,
+            img + 3,
+            3 + img,
+            img * 3,
+            3 * img,
+            img + img,
+            img.sum(),
+            img.reshape(-1),
+            img.float(),
+            torch.stack([img, img]),
+        ]
+        + list(torch.chunk(img, 2))
+        + list(torch.unbind(img))
+    ):
+        assert type(out) is (datapoints.Image if return_type == "datapoint" else torch.Tensor)
+
+    for out in (
+        [
+            mask + t,
+            t + mask,
+            mask * t,
+            t * mask,
+            mask + 3,
+            3 + mask,
+            mask * 3,
+            3 * mask,
+            mask + mask,
+            mask.sum(),
+            mask.reshape(-1),
+            mask.float(),
+            torch.stack([mask, mask]),
+        ]
+        + list(torch.chunk(mask, 2))
+        + list(torch.unbind(mask))
+    ):
+        assert type(out) is (datapoints.Mask if return_type == "datapoint" else torch.Tensor)
+
+    with pytest.raises(TypeError, match="unsupported operand type"):
+        img + mask
+    with pytest.raises(TypeError, match="unsupported operand type"):
+        img * mask
+
+    bboxes = datapoints.BoundingBoxes(
+        [[17, 16, 344, 495], [0, 10, 0, 10]],
+        format=datapoints.BoundingBoxFormat.XYXY,
+        canvas_size=(1000, 1000),
+    )
+    t = torch.rand(2, 4)
+    for out in (
+        [
+            bboxes + t,
+            t + bboxes,
+            bboxes * t,
+            t * bboxes,
+            bboxes + 3,
+            3 + bboxes,
+            bboxes * 3,
+            3 * bboxes,
+            bboxes + bboxes,
+            bboxes.sum(),
+            bboxes.reshape(-1),
+            bboxes.float(),
+            torch.stack([bboxes, bboxes]),
+        ]
+        + list(torch.chunk(bboxes, 2))
+        + list(torch.unbind(bboxes))
+    ):
+        if return_type == "Tensor":
+            assert type(out) is torch.Tensor
+        else:
+            assert isinstance(out, datapoints.BoundingBoxes)
+            assert hasattr(out, "format")
+            assert hasattr(out, "canvas_size")
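A note on the autouse preserve_default_wrapping_behaviour fixture added at the top of this file: set_return_type() mutates module-level state, so without a reset after each test, a test that switches to "datapoint" would leak into later tests. A small standalone illustration of the state it guards against (assuming this branch is installed):

    import torch
    from torchvision import datapoints

    img = datapoints.Image(torch.rand(3, 8, 8))

    datapoints.set_return_type("datapoint")  # global flag, stays set until changed ...
    assert type(img + 1) is datapoints.Image

    datapoints.set_return_type("Tensor")     # ... which is exactly what the fixture resets
    assert type(img + 1) is torch.Tensor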
torchvision/datapoints/__init__.py
+import torch
 from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
 
 from ._bounding_box import BoundingBoxes, BoundingBoxFormat
 from ._datapoint import Datapoint
 from ._image import Image
 from ._mask import Mask
+from ._torch_function_helpers import set_return_type
 from ._video import Video
 
 if _WARN_ABOUT_BETA_TRANSFORMS:
torchvision/datapoints/_bounding_box.py
 from __future__ import annotations
 
 from enum import Enum
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Mapping, Optional, Sequence, Tuple, Union
 
 import torch
+from torch.utils._pytree import tree_flatten
 
 from ._datapoint import Datapoint
@@ -48,11 +49,12 @@ class BoundingBoxes(Datapoint):
     canvas_size: Tuple[int, int]
 
     @classmethod
-    def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes:  # type: ignore[override]
-        if tensor.ndim == 1:
-            tensor = tensor.unsqueeze(0)
-        elif tensor.ndim != 2:
-            raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
+    def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int], check_dims: bool = True) -> BoundingBoxes:  # type: ignore[override]
+        if check_dims:
+            if tensor.ndim == 1:
+                tensor = tensor.unsqueeze(0)
+            elif tensor.ndim != 2:
+                raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
         if isinstance(format, str):
             format = BoundingBoxFormat[format.upper()]
         bounding_boxes = tensor.as_subclass(cls)
@@ -99,5 +101,29 @@ class BoundingBoxes(Datapoint):
             canvas_size=canvas_size if canvas_size is not None else other.canvas_size,
         )
 
+    @classmethod
+    def _wrap_output(
+        cls,
+        output: torch.Tensor,
+        args: Sequence[Any] = (),
+        kwargs: Optional[Mapping[str, Any]] = None,
+    ) -> BoundingBoxes:
+        # If there are BoundingBoxes instances in the output, their metadata got lost when we called
+        # super().__torch_function__. We need to restore the metadata somehow, so we choose to take
+        # the metadata from the first bbox in the parameters.
+        # This should be what we want in most cases. When it's not, it's probably a mis-use anyway, e.g.
+        # something like some_xyxy_bbox + some_xywh_bbox; we don't guard against those cases.
+        flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ()))  # type: ignore[operator]
+        first_bbox_from_args = next(x for x in flat_params if isinstance(x, BoundingBoxes))
+        format, canvas_size = first_bbox_from_args.format, first_bbox_from_args.canvas_size
+
+        if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes):
+            output = BoundingBoxes._wrap(output, format=format, canvas_size=canvas_size, check_dims=False)
+        elif isinstance(output, (tuple, list)):
+            output = type(output)(
+                BoundingBoxes._wrap(part, format=format, canvas_size=canvas_size, check_dims=False) for part in output
+            )
+        return output
+
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
         return self._make_repr(format=self.format, canvas_size=self.canvas_size)
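The _wrap_output override above is what keeps format and canvas_size alive when the return type is "datapoint": the metadata is copied from the first BoundingBoxes found in the call arguments, and check_dims=False skips the 1D/2D shape check since re-wrapped outputs (e.g. a stacked 3D tensor) need not have box shape. The new test_operations test exercises exactly this; a condensed version:

    import torch
    from torchvision import datapoints

    bboxes = datapoints.BoundingBoxes(
        [[0, 0, 5, 5], [2, 2, 7, 7]],
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(32, 32),
    )

    with datapoints.set_return_type("datapoint"):
        stacked = torch.stack([bboxes, bboxes])
        # The plain-Tensor result of torch.stack is re-wrapped, taking the metadata
        # from the first BoundingBoxes in the arguments.
        assert isinstance(stacked, datapoints.BoundingBoxes)
        assert stacked.format == datapoints.BoundingBoxFormat.XYXY
        assert stacked.canvas_size == (32, 32)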
torchvision/datapoints/_datapoint.py
@@ -6,6 +6,8 @@ import torch
 from torch._C import DisableTorchFunctionSubclass
 from torch.types import _device, _dtype, _size
 
+from torchvision.datapoints._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass
+
 
 D = TypeVar("D", bound="Datapoint")
@@ -33,9 +35,21 @@ class Datapoint(torch.Tensor):
     def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
         return tensor.as_subclass(cls)
 
-    # The ops in this set are those that should *preserve* the Datapoint type,
-    # i.e. they are exceptions to the "no wrapping" rule.
-    _NO_WRAPPING_EXCEPTIONS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
+    @classmethod
+    def _wrap_output(
+        cls,
+        output: torch.Tensor,
+        args: Sequence[Any] = (),
+        kwargs: Optional[Mapping[str, Any]] = None,
+    ) -> torch.Tensor:
+        # Same as torch._tensor._convert
+        if isinstance(output, torch.Tensor) and not isinstance(output, cls):
+            output = output.as_subclass(cls)
+
+        if isinstance(output, (tuple, list)):
+            # Also handles things like namedtuples
+            output = type(output)(cls._wrap_output(part, args, kwargs) for part in output)
+        return output
 
     @classmethod
     def __torch_function__(
@@ -60,7 +74,7 @@ class Datapoint(torch.Tensor):
         2. For most operations, there is no way of knowing if the input type is still valid for the output.
 
         For these reasons, the automatic output wrapping is turned off for most operators. The only exceptions are
-        listed in :attr:`Datapoint._NO_WRAPPING_EXCEPTIONS`
+        listed in _FORCE_TORCHFUNCTION_SUBCLASS
         """
         # Since super().__torch_function__ has no hook to prevent the coercing of the output into the input type, we
         # need to reimplement the functionality.
@@ -68,19 +82,22 @@ class Datapoint(torch.Tensor):
         if not all(issubclass(cls, t) for t in types):
             return NotImplemented
 
+        # Like in the base Tensor.__torch_function__ implementation, it's easier to always use
+        # DisableTorchFunctionSubclass and then manually re-wrap the output if necessary
         with DisableTorchFunctionSubclass():
             output = func(*args, **kwargs or dict())
 
-            if func in cls._NO_WRAPPING_EXCEPTIONS and isinstance(args[0], cls):
+            must_return_subclass = _must_return_subclass()
+            if must_return_subclass or (func in _FORCE_TORCHFUNCTION_SUBCLASS and isinstance(args[0], cls)):
                 # We also require the primary operand, i.e. `args[0]`, to be
                 # an instance of the class that `__torch_function__` was invoked on. The __torch_function__ protocol will
                 # invoke this method on *all* types involved in the computation by walking the MRO upwards. For example,
                 # `torch.Tensor(...).to(datapoints.Image(...))` will invoke `datapoints.Image.__torch_function__` with
                 # `args = (torch.Tensor(), datapoints.Image())` first. Without this guard, the original `torch.Tensor` would
                 # be wrapped into a `datapoints.Image`.
-                return cls.wrap_like(args[0], output)
+                return cls._wrap_output(output, args, kwargs)
 
-            if isinstance(output, cls):
+            if not must_return_subclass and isinstance(output, cls):
                 # DisableTorchFunctionSubclass is ignored by inplace ops like `.add_(...)`,
                 # so for those, the output is still a Datapoint. Thus, we need to manually unwrap.
                 return output.as_subclass(torch.Tensor)
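For readers unfamiliar with the __torch_function__ machinery used here, the control flow is: run the op with subclass dispatch disabled, then decide whether to re-wrap the result based on the module-level flag. A hypothetical, self-contained toy version of that pattern (not torchvision code, just an illustration that runs with plain PyTorch):

    import torch
    from torch._C import DisableTorchFunctionSubclass

    _RETURN_SUBCLASS = False  # stand-in for _must_return_subclass()


    class MyTensor(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            # Run the actual op with subclass handling disabled, like Datapoint does.
            with DisableTorchFunctionSubclass():
                output = func(*args, **(kwargs or {}))
            if _RETURN_SUBCLASS and isinstance(output, torch.Tensor):
                return output.as_subclass(cls)  # opt-in re-wrapping
            if isinstance(output, cls):
                # In-place ops ignore DisableTorchFunctionSubclass, so unwrap manually.
                return output.as_subclass(torch.Tensor)
            return output


    t = torch.rand(3).as_subclass(MyTensor)
    print(type(t * 2))    # torch.Tensor -- unwrapped by default
    _RETURN_SUBCLASS = True
    print(type(t * 2))    # MyTensor -- re-wrapped once the flag is set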
torchvision/datapoints/_torch_function_helpers.py (new file, 0 → 100644)
+import torch
+
+_TORCHFUNCTION_SUBCLASS = False
+
+
+class _ReturnTypeCM:
+    def __init__(self, to_restore):
+        self.to_restore = to_restore
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        global _TORCHFUNCTION_SUBCLASS
+        _TORCHFUNCTION_SUBCLASS = self.to_restore
+
+
+def set_return_type(return_type: str):
+    """Set the return type of torch operations on datapoints.
+
+    Can be used as a global flag for the entire program:
+
+    .. code:: python
+
+        set_return_type("datapoint")
+        img = datapoints.Image(torch.rand(3, 5, 5))
+        img + 2  # This is an Image
+
+    or as a context manager to restrict the scope:
+
+    .. code:: python
+
+        img = datapoints.Image(torch.rand(3, 5, 5))
+        with set_return_type("datapoint"):
+            img + 2  # This is an Image
+        img + 2  # This is a pure Tensor
+
+    Args:
+        return_type (str): Can be "datapoint" or "tensor". Default is "tensor".
+    """
+    global _TORCHFUNCTION_SUBCLASS
+    to_restore = _TORCHFUNCTION_SUBCLASS
+
+    _TORCHFUNCTION_SUBCLASS = {"tensor": False, "datapoint": True}[return_type.lower()]
+
+    return _ReturnTypeCM(to_restore)
+
+
+def _must_return_subclass():
+    return _TORCHFUNCTION_SUBCLASS
+
+
+# For those ops we always want to preserve the original subclass instead of returning a pure Tensor
+_FORCE_TORCHFUNCTION_SUBCLASS = {torch.Tensor.clone, torch.Tensor.to, torch.Tensor.detach, torch.Tensor.requires_grad_}
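The _FORCE_TORCHFUNCTION_SUBCLASS set (previously Datapoint._NO_WRAPPING_EXCEPTIONS) is what keeps .clone(), .to(), .detach() and .requires_grad_() returning the subclass even in the default "tensor" mode, as the updated tests assert. Roughly:

    import torch
    from torchvision import datapoints

    img = datapoints.Image(torch.rand(3, 5, 5))

    # These ops are in _FORCE_TORCHFUNCTION_SUBCLASS, so the subclass survives
    # even with the default "tensor" return type:
    assert type(img.clone()) is datapoints.Image
    assert type(img.to(torch.float64)) is datapoints.Image
    assert type(img.detach()) is datapoints.Image

    # Any other op unwraps by default:
    assert type(img * 2) is torch.Tensor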
torchvision/transforms/v2/_misc.py
@@ -401,7 +401,7 @@ class SanitizeBoundingBoxes(Transform):
         valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
         valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
 
-        params = dict(valid=valid, labels=labels)
+        params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels)
         flat_outputs = [
             # Even-though it may look like we're transforming all inputs, we don't:
             # _transform() will only care about BoundingBoxeses and the labels
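Why the .as_subclass(torch.Tensor) in SanitizeBoundingBoxes: valid is computed from slices of a BoundingBoxes instance, so with the "datapoint" return type the boolean mask itself would come back wrapped as a BoundingBoxes; unwrapping keeps it a plain mask regardless of the flag. A rough illustration of the behaviour this guards against (assuming the wrapping semantics introduced above):

    import torch
    from torchvision import datapoints

    boxes = datapoints.BoundingBoxes(
        [[0, 0, 5, 5], [2, 2, 7, 7]], format="XYXY", canvas_size=(10, 10)
    )

    with datapoints.set_return_type("datapoint"):
        valid = boxes[:, 0] <= 10
        print(type(valid))                            # datapoints.BoundingBoxes
        print(type(valid.as_subclass(torch.Tensor)))  # torch.Tensor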
torchvision/transforms/v2/functional/_utils.py
@@ -19,6 +19,8 @@ _KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}
 def _kernel_datapoint_wrapper(kernel):
     @functools.wraps(kernel)
     def wrapper(inpt, *args, **kwargs):
+        # We always pass datapoints as pure tensors to the kernels to avoid going through the
+        # Tensor.__torch_function__ logic, which is costly.
         output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
         return type(inpt).wrap_like(inpt, output)
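Net effect of the comment added here: kernels operate on pure tensors (skipping the __torch_function__ overhead), and the wrapper re-wraps the result via wrap_like, so transforms v2 functionals should keep returning the datapoint subclass for datapoint inputs regardless of set_return_type. A quick check, assuming horizontal_flip from the v2 functional API is available on this branch:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    img = datapoints.Image(torch.rand(3, 32, 32))

    out = F.horizontal_flip(img)
    # Re-wrapped by the kernel wrapper above, independent of the return-type flag.
    print(type(out))  # datapoints.Image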