Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
9b82df43
Unverified
Commit
9b82df43
authored
Aug 07, 2023
by
Nicolas Hug
Committed by
GitHub
Aug 07, 2023
Browse files
Remove `_wrap()` class method from base class Datapoint (#7805)
parent
2030d208
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
29 additions
and
18 deletions
+29
-18
test/test_datapoints.py
test/test_datapoints.py
+20
-0
torchvision/datapoints/_bounding_box.py
torchvision/datapoints/_bounding_box.py
+4
-9
torchvision/datapoints/_datapoint.py
torchvision/datapoints/_datapoint.py
+1
-5
torchvision/datapoints/_image.py
torchvision/datapoints/_image.py
+1
-1
torchvision/datapoints/_mask.py
torchvision/datapoints/_mask.py
+1
-1
torchvision/datapoints/_video.py
torchvision/datapoints/_video.py
+1
-1
torchvision/prototype/datapoints/_label.py
torchvision/prototype/datapoints/_label.py
+1
-1
No files found.
test/test_datapoints.py
View file @
9b82df43
...
@@ -113,6 +113,26 @@ def test_detach_wrapping():
...
@@ -113,6 +113,26 @@ def test_detach_wrapping():
assert
type
(
image_detached
)
is
datapoints
.
Image
assert
type
(
image_detached
)
is
datapoints
.
Image
def
test_no_wrapping_exceptions_with_metadata
():
# Sanity checks for the ops in _NO_WRAPPING_EXCEPTIONS and datapoints with metadata
format
,
canvas_size
=
"XYXY"
,
(
32
,
32
)
bbox
=
datapoints
.
BoundingBoxes
([[
0
,
0
,
5
,
5
],
[
2
,
2
,
7
,
7
]],
format
=
format
,
canvas_size
=
canvas_size
)
bbox
=
bbox
.
clone
()
assert
bbox
.
format
,
bbox
.
canvas_size
==
(
format
,
canvas_size
)
bbox
=
bbox
.
to
(
torch
.
float64
)
assert
bbox
.
format
,
bbox
.
canvas_size
==
(
format
,
canvas_size
)
bbox
=
bbox
.
detach
()
assert
bbox
.
format
,
bbox
.
canvas_size
==
(
format
,
canvas_size
)
assert
not
bbox
.
requires_grad
bbox
.
requires_grad_
(
True
)
assert
bbox
.
format
,
bbox
.
canvas_size
==
(
format
,
canvas_size
)
assert
bbox
.
requires_grad
def
test_other_op_no_wrapping
():
def
test_other_op_no_wrapping
():
image
=
datapoints
.
Image
(
torch
.
rand
(
3
,
16
,
16
))
image
=
datapoints
.
Image
(
torch
.
rand
(
3
,
16
,
16
))
...
...
torchvision/datapoints/_bounding_box.py
View file @
9b82df43
...
@@ -42,7 +42,9 @@ class BoundingBoxes(Datapoint):
...
@@ -42,7 +42,9 @@ class BoundingBoxes(Datapoint):
canvas_size
:
Tuple
[
int
,
int
]
canvas_size
:
Tuple
[
int
,
int
]
@
classmethod
@
classmethod
def
_wrap
(
cls
,
tensor
:
torch
.
Tensor
,
*
,
format
:
BoundingBoxFormat
,
canvas_size
:
Tuple
[
int
,
int
])
->
BoundingBoxes
:
# type: ignore[override]
def
_wrap
(
cls
,
tensor
:
torch
.
Tensor
,
*
,
format
:
Union
[
BoundingBoxFormat
,
str
],
canvas_size
:
Tuple
[
int
,
int
])
->
BoundingBoxes
:
# type: ignore[override]
if
isinstance
(
format
,
str
):
format
=
BoundingBoxFormat
[
format
.
upper
()]
bounding_boxes
=
tensor
.
as_subclass
(
cls
)
bounding_boxes
=
tensor
.
as_subclass
(
cls
)
bounding_boxes
.
format
=
format
bounding_boxes
.
format
=
format
bounding_boxes
.
canvas_size
=
canvas_size
bounding_boxes
.
canvas_size
=
canvas_size
...
@@ -59,10 +61,6 @@ class BoundingBoxes(Datapoint):
...
@@ -59,10 +61,6 @@ class BoundingBoxes(Datapoint):
requires_grad
:
Optional
[
bool
]
=
None
,
requires_grad
:
Optional
[
bool
]
=
None
,
)
->
BoundingBoxes
:
)
->
BoundingBoxes
:
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
if
isinstance
(
format
,
str
):
format
=
BoundingBoxFormat
[
format
.
upper
()]
return
cls
.
_wrap
(
tensor
,
format
=
format
,
canvas_size
=
canvas_size
)
return
cls
.
_wrap
(
tensor
,
format
=
format
,
canvas_size
=
canvas_size
)
@
classmethod
@
classmethod
...
@@ -71,7 +69,7 @@ class BoundingBoxes(Datapoint):
...
@@ -71,7 +69,7 @@ class BoundingBoxes(Datapoint):
other
:
BoundingBoxes
,
other
:
BoundingBoxes
,
tensor
:
torch
.
Tensor
,
tensor
:
torch
.
Tensor
,
*
,
*
,
format
:
Optional
[
BoundingBoxFormat
]
=
None
,
format
:
Optional
[
Union
[
BoundingBoxFormat
,
str
]
]
=
None
,
canvas_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
canvas_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
)
->
BoundingBoxes
:
)
->
BoundingBoxes
:
"""Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.
"""Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.
...
@@ -85,9 +83,6 @@ class BoundingBoxes(Datapoint):
...
@@ -85,9 +83,6 @@ class BoundingBoxes(Datapoint):
omitted, it is taken from the reference.
omitted, it is taken from the reference.
"""
"""
if
isinstance
(
format
,
str
):
format
=
BoundingBoxFormat
[
format
.
upper
()]
return
cls
.
_wrap
(
return
cls
.
_wrap
(
tensor
,
tensor
,
format
=
format
if
format
is
not
None
else
other
.
format
,
format
=
format
if
format
is
not
None
else
other
.
format
,
...
...
torchvision/datapoints/_datapoint.py
View file @
9b82df43
...
@@ -32,13 +32,9 @@ class Datapoint(torch.Tensor):
...
@@ -32,13 +32,9 @@ class Datapoint(torch.Tensor):
requires_grad
=
data
.
requires_grad
if
isinstance
(
data
,
torch
.
Tensor
)
else
False
requires_grad
=
data
.
requires_grad
if
isinstance
(
data
,
torch
.
Tensor
)
else
False
return
torch
.
as_tensor
(
data
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
return
torch
.
as_tensor
(
data
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
@
classmethod
def
_wrap
(
cls
:
Type
[
D
],
tensor
:
torch
.
Tensor
)
->
D
:
return
tensor
.
as_subclass
(
cls
)
@
classmethod
@
classmethod
def
wrap_like
(
cls
:
Type
[
D
],
other
:
D
,
tensor
:
torch
.
Tensor
)
->
D
:
def
wrap_like
(
cls
:
Type
[
D
],
other
:
D
,
tensor
:
torch
.
Tensor
)
->
D
:
return
cls
.
_wrap
(
tensor
)
return
tensor
.
as_subclass
(
cls
)
_NO_WRAPPING_EXCEPTIONS
=
{
_NO_WRAPPING_EXCEPTIONS
=
{
torch
.
Tensor
.
clone
:
lambda
cls
,
input
,
output
:
cls
.
wrap_like
(
input
,
output
),
torch
.
Tensor
.
clone
:
lambda
cls
,
input
,
output
:
cls
.
wrap_like
(
input
,
output
),
...
...
torchvision/datapoints/_image.py
View file @
9b82df43
...
@@ -41,7 +41,7 @@ class Image(Datapoint):
...
@@ -41,7 +41,7 @@ class Image(Datapoint):
elif
tensor
.
ndim
==
2
:
elif
tensor
.
ndim
==
2
:
tensor
=
tensor
.
unsqueeze
(
0
)
tensor
=
tensor
.
unsqueeze
(
0
)
return
cls
.
_wrap
(
tensor
)
return
tensor
.
as_subclass
(
cls
)
def
__repr__
(
self
,
*
,
tensor_contents
:
Any
=
None
)
->
str
:
# type: ignore[override]
def
__repr__
(
self
,
*
,
tensor_contents
:
Any
=
None
)
->
str
:
# type: ignore[override]
return
self
.
_make_repr
()
return
self
.
_make_repr
()
...
...
torchvision/datapoints/_mask.py
View file @
9b82df43
...
@@ -36,4 +36,4 @@ class Mask(Datapoint):
...
@@ -36,4 +36,4 @@ class Mask(Datapoint):
data
=
F
.
pil_to_tensor
(
data
)
data
=
F
.
pil_to_tensor
(
data
)
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
return
cls
.
_wrap
(
tensor
)
return
tensor
.
as_subclass
(
cls
)
torchvision/datapoints/_video.py
View file @
9b82df43
...
@@ -31,7 +31,7 @@ class Video(Datapoint):
...
@@ -31,7 +31,7 @@ class Video(Datapoint):
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
if
data
.
ndim
<
4
:
if
data
.
ndim
<
4
:
raise
ValueError
raise
ValueError
return
cls
.
_wrap
(
tensor
)
return
tensor
.
as_subclass
(
cls
)
def
__repr__
(
self
,
*
,
tensor_contents
:
Any
=
None
)
->
str
:
# type: ignore[override]
def
__repr__
(
self
,
*
,
tensor_contents
:
Any
=
None
)
->
str
:
# type: ignore[override]
return
self
.
_make_repr
()
return
self
.
_make_repr
()
...
...
torchvision/prototype/datapoints/_label.py
View file @
9b82df43
...
@@ -15,7 +15,7 @@ class _LabelBase(Datapoint):
...
@@ -15,7 +15,7 @@ class _LabelBase(Datapoint):
categories
:
Optional
[
Sequence
[
str
]]
categories
:
Optional
[
Sequence
[
str
]]
@
classmethod
@
classmethod
def
_wrap
(
cls
:
Type
[
L
],
tensor
:
torch
.
Tensor
,
*
,
categories
:
Optional
[
Sequence
[
str
]])
->
L
:
# type: ignore[override]
def
_wrap
(
cls
:
Type
[
L
],
tensor
:
torch
.
Tensor
,
*
,
categories
:
Optional
[
Sequence
[
str
]])
->
L
:
label_base
=
tensor
.
as_subclass
(
cls
)
label_base
=
tensor
.
as_subclass
(
cls
)
label_base
.
categories
=
categories
label_base
.
categories
=
categories
return
label_base
return
label_base
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment