Unverified Commit c7c2085e authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Bugfix in BalancedPositiveNegativeSampler introduced during torchscript support (#1670)

parent bce17fdd
import torch
from torchvision.models.detection import _utils
import unittest
class Tester(unittest.TestCase):
def test_balanced_positive_negative_sampler(self):
sampler = _utils.BalancedPositiveNegativeSampler(4, 0.25)
# keep all 6 negatives first, then add 3 positives, last two are ignore
matched_idxs = [torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, -1, -1])]
pos, neg = sampler(matched_idxs)
# we know the number of elements that should be sampled for the positive (1)
# and the negative (3), and their location. Let's make sure that they are
# there
self.assertEqual(pos[0].sum(), 1)
self.assertEqual(pos[0][6:9].sum(), 1)
self.assertEqual(neg[0].sum(), 3)
self.assertEqual(neg[0][0:6].sum(), 3)
if __name__ == '__main__':
unittest.main()
......@@ -11,10 +11,8 @@ import torchvision
# TODO: https://github.com/pytorch/pytorch/issues/26727
def zeros_like(tensor, dtype):
# type: (Tensor, int) -> Tensor
if tensor.dtype == dtype:
return tensor.detach().clone()
else:
return tensor.to(dtype)
return torch.zeros_like(tensor, dtype=dtype, layout=tensor.layout,
device=tensor.device, pin_memory=tensor.is_pinned())
@torch.jit.script
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment