Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
2bc8a14d
Unverified
Commit
2bc8a14d
authored
Jan 27, 2023
by
Philip Meier
Committed by
GitHub
Jan 27, 2023
Browse files
fix requires_grad passthrough (#7138)
parent
455eda68
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
28 additions
and
7 deletions
+28
-7
test/test_prototype_datapoints.py
test/test_prototype_datapoints.py
+19
-0
torchvision/prototype/datapoints/_bounding_box.py
torchvision/prototype/datapoints/_bounding_box.py
+1
-1
torchvision/prototype/datapoints/_datapoint.py
torchvision/prototype/datapoints/_datapoint.py
+4
-2
torchvision/prototype/datapoints/_image.py
torchvision/prototype/datapoints/_image.py
+1
-1
torchvision/prototype/datapoints/_label.py
torchvision/prototype/datapoints/_label.py
+1
-1
torchvision/prototype/datapoints/_mask.py
torchvision/prototype/datapoints/_mask.py
+1
-1
torchvision/prototype/datapoints/_video.py
torchvision/prototype/datapoints/_video.py
+1
-1
No files found.
test/test_prototype_datapoints.py
View file @
2bc8a14d
...
@@ -3,6 +3,25 @@ import torch
...
@@ -3,6 +3,25 @@ import torch
from
torchvision.prototype
import
datapoints
from
torchvision.prototype
import
datapoints
@
pytest
.
mark
.
parametrize
(
(
"data"
,
"input_requires_grad"
,
"expected_requires_grad"
),
[
([
0.0
],
None
,
False
),
([
0.0
],
False
,
False
),
([
0.0
],
True
,
True
),
(
torch
.
tensor
([
0.0
],
requires_grad
=
False
),
None
,
False
),
(
torch
.
tensor
([
0.0
],
requires_grad
=
False
),
False
,
False
),
(
torch
.
tensor
([
0.0
],
requires_grad
=
False
),
True
,
True
),
(
torch
.
tensor
([
0.0
],
requires_grad
=
True
),
None
,
True
),
(
torch
.
tensor
([
0.0
],
requires_grad
=
True
),
False
,
False
),
(
torch
.
tensor
([
0.0
],
requires_grad
=
True
),
True
,
True
),
],
)
def
test_new_requires_grad
(
data
,
input_requires_grad
,
expected_requires_grad
):
datapoint
=
datapoints
.
Label
(
data
,
requires_grad
=
input_requires_grad
)
assert
datapoint
.
requires_grad
is
expected_requires_grad
def
test_isinstance
():
def
test_isinstance
():
assert
isinstance
(
assert
isinstance
(
datapoints
.
Label
([
0
,
1
,
0
],
categories
=
[
"foo"
,
"bar"
]),
datapoints
.
Label
([
0
,
1
,
0
],
categories
=
[
"foo"
,
"bar"
]),
...
...
torchvision/prototype/datapoints/_bounding_box.py
View file @
2bc8a14d
...
@@ -34,7 +34,7 @@ class BoundingBox(Datapoint):
...
@@ -34,7 +34,7 @@ class BoundingBox(Datapoint):
spatial_size
:
Tuple
[
int
,
int
],
spatial_size
:
Tuple
[
int
,
int
],
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
Fals
e
,
requires_grad
:
Optional
[
bool
]
=
Non
e
,
)
->
BoundingBox
:
)
->
BoundingBox
:
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
...
...
torchvision/prototype/datapoints/_datapoint.py
View file @
2bc8a14d
...
@@ -23,8 +23,10 @@ class Datapoint(torch.Tensor):
...
@@ -23,8 +23,10 @@ class Datapoint(torch.Tensor):
data
:
Any
,
data
:
Any
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
Fals
e
,
requires_grad
:
Optional
[
bool
]
=
Non
e
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
requires_grad
is
None
:
requires_grad
=
data
.
requires_grad
if
isinstance
(
data
,
torch
.
Tensor
)
else
False
return
torch
.
as_tensor
(
data
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
return
torch
.
as_tensor
(
data
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
# FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a
# FIXME: this is just here for BC with the prototype datasets. Some datasets use the Datapoint directly to have a
...
@@ -36,7 +38,7 @@ class Datapoint(torch.Tensor):
...
@@ -36,7 +38,7 @@ class Datapoint(torch.Tensor):
data
:
Any
,
data
:
Any
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
Fals
e
,
requires_grad
:
Optional
[
bool
]
=
Non
e
,
)
->
Datapoint
:
)
->
Datapoint
:
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
return
tensor
.
as_subclass
(
Datapoint
)
return
tensor
.
as_subclass
(
Datapoint
)
...
...
torchvision/prototype/datapoints/_image.py
View file @
2bc8a14d
...
@@ -21,7 +21,7 @@ class Image(Datapoint):
...
@@ -21,7 +21,7 @@ class Image(Datapoint):
*
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
Fals
e
,
requires_grad
:
Optional
[
bool
]
=
Non
e
,
)
->
Image
:
)
->
Image
:
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
if
tensor
.
ndim
<
2
:
if
tensor
.
ndim
<
2
:
...
...
torchvision/prototype/datapoints/_label.py
View file @
2bc8a14d
...
@@ -27,7 +27,7 @@ class _LabelBase(Datapoint):
...
@@ -27,7 +27,7 @@ class _LabelBase(Datapoint):
categories
:
Optional
[
Sequence
[
str
]]
=
None
,
categories
:
Optional
[
Sequence
[
str
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
Fals
e
,
requires_grad
:
Optional
[
bool
]
=
Non
e
,
)
->
L
:
)
->
L
:
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
return
cls
.
_wrap
(
tensor
,
categories
=
categories
)
return
cls
.
_wrap
(
tensor
,
categories
=
categories
)
...
...
torchvision/prototype/datapoints/_mask.py
View file @
2bc8a14d
...
@@ -19,7 +19,7 @@ class Mask(Datapoint):
...
@@ -19,7 +19,7 @@ class Mask(Datapoint):
*
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
Fals
e
,
requires_grad
:
Optional
[
bool
]
=
Non
e
,
)
->
Mask
:
)
->
Mask
:
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
return
cls
.
_wrap
(
tensor
)
return
cls
.
_wrap
(
tensor
)
...
...
torchvision/prototype/datapoints/_video.py
View file @
2bc8a14d
...
@@ -20,7 +20,7 @@ class Video(Datapoint):
...
@@ -20,7 +20,7 @@ class Video(Datapoint):
*
,
*
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
device
:
Optional
[
Union
[
torch
.
device
,
str
,
int
]]
=
None
,
requires_grad
:
bool
=
Fals
e
,
requires_grad
:
Optional
[
bool
]
=
Non
e
,
)
->
Video
:
)
->
Video
:
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
tensor
=
cls
.
_to_tensor
(
data
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
requires_grad
)
if
data
.
ndim
<
4
:
if
data
.
ndim
<
4
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment