Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c7c2085e
Unverified
Commit
c7c2085e
authored
Dec 16, 2019
by
Francisco Massa
Committed by
GitHub
Dec 16, 2019
Browse files
Bugfix in BalancedPositiveNegativeSampler introduced during torchscript support (#1670)
parent
bce17fdd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
4 deletions
+24
-4
test/test_models_detection_utils.py
test/test_models_detection_utils.py
+22
-0
torchvision/models/detection/_utils.py
torchvision/models/detection/_utils.py
+2
-4
No files found.
test/test_models_detection_utils.py
0 → 100644
View file @
c7c2085e
import
torch
from
torchvision.models.detection
import
_utils
import
unittest
class
Tester
(
unittest
.
TestCase
):
def
test_balanced_positive_negative_sampler
(
self
):
sampler
=
_utils
.
BalancedPositiveNegativeSampler
(
4
,
0.25
)
# keep all 6 negatives first, then add 3 positives, last two are ignore
matched_idxs
=
[
torch
.
tensor
([
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
-
1
,
-
1
])]
pos
,
neg
=
sampler
(
matched_idxs
)
# we know the number of elements that should be sampled for the positive (1)
# and the negative (3), and their location. Let's make sure that they are
# there
self
.
assertEqual
(
pos
[
0
].
sum
(),
1
)
self
.
assertEqual
(
pos
[
0
][
6
:
9
].
sum
(),
1
)
self
.
assertEqual
(
neg
[
0
].
sum
(),
3
)
self
.
assertEqual
(
neg
[
0
][
0
:
6
].
sum
(),
3
)
if
__name__
==
'__main__'
:
unittest
.
main
()
torchvision/models/detection/_utils.py
View file @
c7c2085e
...
...
@@ -11,10 +11,8 @@ import torchvision
# TODO: https://github.com/pytorch/pytorch/issues/26727
def
zeros_like
(
tensor
,
dtype
):
# type: (Tensor, int) -> Tensor
if
tensor
.
dtype
==
dtype
:
return
tensor
.
detach
().
clone
()
else
:
return
tensor
.
to
(
dtype
)
return
torch
.
zeros_like
(
tensor
,
dtype
=
dtype
,
layout
=
tensor
.
layout
,
device
=
tensor
.
device
,
pin_memory
=
tensor
.
is_pinned
())
@
torch
.
jit
.
script
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment