OpenDAS / vision · Commits

Commit d814772e, authored Jun 27, 2023 by Philip Meier, committed via GitHub on Jun 27, 2023

make datapoints deepcopyable (#7701)
Parent: 357a40f1

Showing 3 changed files with 165 additions and 134 deletions:

  test/test_datapoints.py                +154  -0
  test/test_prototype_datapoints.py      +0    -133
  torchvision/datapoints/_datapoint.py   +11   -1
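In practice, the change means `copy.deepcopy` now round-trips datapoint subclasses. A minimal sketch of the new behavior, mirroring the `test_deepcopy` test added below (assumes a torchvision build that ships `torchvision.datapoints`):

from copy import deepcopy

import torch
from torchvision import datapoints

image = datapoints.Image(torch.rand(3, 16, 16), requires_grad=True)
copied = deepcopy(image)

assert type(copied) is datapoints.Image        # subclass type survives the copy
assert copied.data_ptr() != image.data_ptr()   # storage is actually duplicated
assert copied.requires_grad and copied.is_leaf # autograd state is restored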
test/test_datapoints.py (view file @ d814772e)
from copy import deepcopy

import pytest
import torch

from common_utils import assert_equal
from PIL import Image

from torchvision import datapoints

...
@@ -30,3 +33,154 @@ def test_bbox_instance(data, format):
    if isinstance(format, str):
        format = datapoints.BoundingBoxFormat[format.upper()]

    assert bboxes.format == format
@pytest.mark.parametrize(
    ("data", "input_requires_grad", "expected_requires_grad"),
    [
        ([[[0.0, 1.0], [0.0, 1.0]]], None, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], False, False),
        ([[[0.0, 1.0], [0.0, 1.0]]], True, True),
        (torch.rand(3, 16, 16, requires_grad=False), None, False),
        (torch.rand(3, 16, 16, requires_grad=False), False, False),
        (torch.rand(3, 16, 16, requires_grad=False), True, True),
        (torch.rand(3, 16, 16, requires_grad=True), None, True),
        (torch.rand(3, 16, 16, requires_grad=True), False, False),
        (torch.rand(3, 16, 16, requires_grad=True), True, True),
    ],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
    datapoint = datapoints.Image(data, requires_grad=input_requires_grad)
    assert datapoint.requires_grad is expected_requires_grad
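Read together, the cases encode a simple rule: an explicit requires_grad argument always wins, and None defers to whatever the input already carries (False for plain Python data). A quick sketch of the two branches:

t = torch.rand(3, 16, 16, requires_grad=True)
assert datapoints.Image(t).requires_grad                           # None → inherited from t
assert not datapoints.Image(t, requires_grad=False).requires_grad  # explicit flag overrides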
def test_isinstance():
    assert isinstance(datapoints.Image(torch.rand(3, 16, 16)), torch.Tensor)
def test_wrapping_no_copy():
    tensor = torch.rand(3, 16, 16)
    image = datapoints.Image(tensor)

    assert image.data_ptr() == tensor.data_ptr()
def test_to_wrapping():
    image = datapoints.Image(torch.rand(3, 16, 16))

    image_to = image.to(torch.float64)

    assert type(image_to) is datapoints.Image
    assert image_to.dtype is torch.float64
def test_to_datapoint_reference():
    tensor = torch.rand((3, 16, 16), dtype=torch.float64)
    image = datapoints.Image(tensor)

    tensor_to = tensor.to(image)

    assert type(tensor_to) is torch.Tensor
    assert tensor_to.dtype is torch.float64
def test_clone_wrapping():
    image = datapoints.Image(torch.rand(3, 16, 16))

    image_clone = image.clone()

    assert type(image_clone) is datapoints.Image
    assert image_clone.data_ptr() != image.data_ptr()
def test_requires_grad__wrapping():
    image = datapoints.Image(torch.rand(3, 16, 16))

    assert not image.requires_grad

    image_requires_grad = image.requires_grad_(True)

    assert type(image_requires_grad) is datapoints.Image
    assert image.requires_grad
    assert image_requires_grad.requires_grad
def test_detach_wrapping():
    image = datapoints.Image(torch.rand(3, 16, 16), requires_grad=True)

    image_detached = image.detach()

    assert type(image_detached) is datapoints.Image
def test_other_op_no_wrapping():
    image = datapoints.Image(torch.rand(3, 16, 16))

    # any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here
    output = image * 2

    assert type(output) is torch.Tensor
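The comment states the general contract: only the handful of ops registered in `Datapoint._NO_WRAPPING_EXCEPTIONS` (see the `_datapoint.py` diff below) keep the subclass; everything else deliberately degrades to a plain tensor. A hedged sketch of that asymmetry:

image = datapoints.Image(torch.rand(3, 16, 16))
assert type(image.clone()) is datapoints.Image  # registered exception → stays wrapped
assert type(image + 1) is torch.Tensor          # ordinary op → unwrapped on purpose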
@pytest.mark.parametrize(
    "op",
    [
        lambda t: t.numpy(),
        lambda t: t.tolist(),
        lambda t: t.max(dim=-1),
    ],
)
def test_no_tensor_output_op_no_wrapping(op):
    image = datapoints.Image(torch.rand(3, 16, 16))

    output = op(image)

    assert type(output) is not datapoints.Image
def test_inplace_op_no_wrapping():
    image = datapoints.Image(torch.rand(3, 16, 16))

    output = image.add_(0)

    assert type(output) is torch.Tensor
    assert type(image) is datapoints.Image
def test_wrap_like():
    image = datapoints.Image(torch.rand(3, 16, 16))

    # any operation besides the ones listed in `Datapoint._NO_WRAPPING_EXCEPTIONS` will do here
    output = image * 2

    image_new = datapoints.Image.wrap_like(image, output)

    assert type(image_new) is datapoints.Image
    assert image_new.data_ptr() == output.data_ptr()
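`wrap_like` is the inverse of the unwrapping above: it restores the subclass around an existing tensor without copying it (note the matching data_ptr). A plausible usage sketch, assuming some plain-tensor math has been done on a datapoint:

image = datapoints.Image(torch.rand(3, 16, 16))
output = torch.clamp(image * 2, 0.0, 1.0)           # plain torch.Tensor result
output = datapoints.Image.wrap_like(image, output)  # re-attach the Image type, zero-copy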
@pytest.mark.parametrize(
    "datapoint",
    [
        datapoints.Image(torch.rand(3, 16, 16)),
        datapoints.Video(torch.rand(2, 3, 16, 16)),
        datapoints.BoundingBox([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)),
        datapoints.Mask(torch.randint(0, 256, (16, 16), dtype=torch.uint8)),
    ],
)
@pytest.mark.parametrize("requires_grad", [False, True])
def test_deepcopy(datapoint, requires_grad):
    if requires_grad and not datapoint.dtype.is_floating_point:
        return

    datapoint.requires_grad_(requires_grad)

    datapoint_deepcopied = deepcopy(datapoint)

    assert datapoint_deepcopied is not datapoint
    assert datapoint_deepcopied.data_ptr() != datapoint.data_ptr()
    assert_equal(datapoint_deepcopied, datapoint)

    assert type(datapoint_deepcopied) is type(datapoint)
    assert datapoint_deepcopied.requires_grad is requires_grad
    assert datapoint_deepcopied.is_leaf
test/test_prototype_datapoints.py (deleted, 100644 → 0; view file @ 357a40f1)
import pytest
import torch

from torchvision.prototype import datapoints as proto_datapoints
@pytest.mark.parametrize(
    ("data", "input_requires_grad", "expected_requires_grad"),
    [
        ([0.0], None, False),
        ([0.0], False, False),
        ([0.0], True, True),
        (torch.tensor([0.0], requires_grad=False), None, False),
        (torch.tensor([0.0], requires_grad=False), False, False),
        (torch.tensor([0.0], requires_grad=False), True, True),
        (torch.tensor([0.0], requires_grad=True), None, True),
        (torch.tensor([0.0], requires_grad=True), False, False),
        (torch.tensor([0.0], requires_grad=True), True, True),
    ],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
    datapoint = proto_datapoints.Label(data, requires_grad=input_requires_grad)
    assert datapoint.requires_grad is expected_requires_grad
def test_isinstance():
    assert isinstance(
        proto_datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
        torch.Tensor,
    )
def test_wrapping_no_copy():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

    assert label.data_ptr() == tensor.data_ptr()
def test_to_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

    label_to = label.to(torch.int32)

    assert type(label_to) is proto_datapoints.Label
    assert label_to.dtype is torch.int32
    assert label_to.categories is label.categories
def test_to_datapoint_reference():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = proto_datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)

    tensor_to = tensor.to(label)

    assert type(tensor_to) is torch.Tensor
    assert tensor_to.dtype is torch.int32
def test_clone_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

    label_clone = label.clone()

    assert type(label_clone) is proto_datapoints.Label
    assert label_clone.data_ptr() != label.data_ptr()
    assert label_clone.categories is label.categories
def test_requires_grad__wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

    assert not label.requires_grad

    label_requires_grad = label.requires_grad_(True)

    assert type(label_requires_grad) is proto_datapoints.Label
    assert label.requires_grad
    assert label_requires_grad.requires_grad
def test_other_op_no_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

    # any operation besides .to() and .clone() will do here
    output = label * 2

    assert type(output) is torch.Tensor
@pytest.mark.parametrize(
    "op",
    [
        lambda t: t.numpy(),
        lambda t: t.tolist(),
        lambda t: t.max(dim=-1),
    ],
)
def test_no_tensor_output_op_no_wrapping(op):
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

    output = op(label)

    assert type(output) is not proto_datapoints.Label
def test_inplace_op_no_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

    output = label.add_(0)

    assert type(output) is torch.Tensor
    assert type(label) is proto_datapoints.Label
def test_wrap_like():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

    # any operation besides .to() and .clone() will do here
    output = label * 2

    label_new = proto_datapoints.Label.wrap_like(label, output)

    assert type(label_new) is proto_datapoints.Label
    assert label_new.data_ptr() == output.data_ptr()
    assert label_new.categories is label.categories
torchvision/datapoints/_datapoint.py (view file @ d814772e)
from __future__ import annotations

from types import ModuleType
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

import PIL.Image
import torch

...
@@ -36,6 +36,7 @@ class Datapoint(torch.Tensor):
     _NO_WRAPPING_EXCEPTIONS = {
         torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
         torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
+        torch.Tensor.detach: lambda cls, input, output: cls.wrap_like(input, output),
         # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
         # retains the type automatically
         torch.Tensor.requires_grad_: lambda cls, input, output: output,
...
@@ -132,6 +133,15 @@ class Datapoint(torch.Tensor):
         with DisableTorchFunctionSubclass():
             return super().dtype
+    def __deepcopy__(self: D, memo: Dict[int, Any]) -> D:
+        # We need to detach first, since a plain `Tensor.clone` will be part of the computation graph, which does
+        # *not* happen for `deepcopy(Tensor)`. A side-effect from detaching is that the `Tensor.requires_grad`
+        # attribute is cleared, so we need to refill it before we return.
+        # Note: We don't explicitly handle deep-copying of the metadata here. The only metadata we currently have is
+        # `BoundingBox.format` and `BoundingBox.spatial_size`, which are immutable and thus implicitly deep-copied by
+        # `BoundingBox.clone()`.
+        return self.detach().clone().requires_grad_(self.requires_grad)  # type: ignore[return-value]
     def horizontal_flip(self) -> Datapoint:
         return self
...
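Taken together, `__deepcopy__` leans on the wrapping table above: `detach()` and `clone()` are both registered in `_NO_WRAPPING_EXCEPTIONS`, so the chain never loses the subclass, and `requires_grad_()` restores the autograd flag that detaching cleared. A hedged walkthrough of the chain (names as in the diff):

copy = image.detach()                             # breaks the graph; stays an Image via the exceptions table
copy = copy.clone()                               # fresh storage; still an Image, now outside any graph
copy = copy.requires_grad_(image.requires_grad)   # refill the flag that detach() cleared

Because the clone is created outside any computation graph, the deep copy is always a leaf tensor, which is exactly what `test_deepcopy` asserts with `is_leaf`.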