Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
d74404ad
Unverified
Commit
d74404ad
authored
May 19, 2020
by
Francisco Massa
Committed by
GitHub
May 19, 2020
Browse files
Make copy of targets in GeneralizedRCNNTransform (#2227)
* Make copy of targets in GeneralizedRCNNTransform * Fix flake8
parent
222a599e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
0 deletions
+23
-0
test/test_models_detection_utils.py
test/test_models_detection_utils.py
+11
-0
torchvision/models/detection/transform.py
torchvision/models/detection/transform.py
+12
-0
No files found.
test/test_models_detection_utils.py
View file @
d74404ad
import
copy
import
torch
import
torch
from
torchvision.models.detection
import
_utils
from
torchvision.models.detection
import
_utils
from
torchvision.models.detection.transform
import
GeneralizedRCNNTransform
import
unittest
import
unittest
from
torchvision.models.detection
import
fasterrcnn_resnet50_fpn
from
torchvision.models.detection
import
fasterrcnn_resnet50_fpn
...
@@ -33,6 +35,15 @@ class Tester(unittest.TestCase):
...
@@ -33,6 +35,15 @@ class Tester(unittest.TestCase):
# check that expected initial number of layers are frozen
# check that expected initial number of layers are frozen
self
.
assertTrue
(
all
(
is_frozen
[:
exp_froz_params
]))
self
.
assertTrue
(
all
(
is_frozen
[:
exp_froz_params
]))
def
test_transform_copy_targets
(
self
):
transform
=
GeneralizedRCNNTransform
(
300
,
500
,
torch
.
zeros
(
3
),
torch
.
ones
(
3
))
image
=
[
torch
.
rand
(
3
,
200
,
300
),
torch
.
rand
(
3
,
200
,
200
)]
targets
=
[{
'boxes'
:
torch
.
rand
(
3
,
4
)},
{
'boxes'
:
torch
.
rand
(
2
,
4
)}]
targets_copy
=
copy
.
deepcopy
(
targets
)
out
=
transform
(
image
,
targets
)
# noqa: F841
self
.
assertTrue
(
torch
.
equal
(
targets
[
0
][
'boxes'
],
targets_copy
[
0
][
'boxes'
]))
self
.
assertTrue
(
torch
.
equal
(
targets
[
1
][
'boxes'
],
targets_copy
[
1
][
'boxes'
]))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
torchvision/models/detection/transform.py
View file @
d74404ad
...
@@ -82,6 +82,18 @@ class GeneralizedRCNNTransform(nn.Module):
...
@@ -82,6 +82,18 @@ class GeneralizedRCNNTransform(nn.Module):
):
):
# type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
# type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
images
=
[
img
for
img
in
images
]
images
=
[
img
for
img
in
images
]
if
targets
is
not
None
:
# make a copy of targets to avoid modifying it in-place
# once torchscript supports dict comprehension
# this can be simplified as as follows
# targets = [{k: v for k,v in t.items()} for t in targets]
targets_copy
:
List
[
Dict
[
str
,
Tensor
]]
=
[]
for
t
in
targets
:
data
:
Dict
[
str
,
Tensor
]
=
{}
for
k
,
v
in
t
.
items
():
data
[
k
]
=
v
targets_copy
.
append
(
data
)
targets
=
targets_copy
for
i
in
range
(
len
(
images
)):
for
i
in
range
(
len
(
images
)):
image
=
images
[
i
]
image
=
images
[
i
]
target_index
=
targets
[
i
]
if
targets
is
not
None
else
None
target_index
=
targets
[
i
]
if
targets
is
not
None
else
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment