Unverified Commit 368d1c6f authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Updated tests for CopyPaste on OneHotLabel (#6485)

* [proto] Updated tests for CopyPaste on OneHotLabel

* Fixing test error
parent f82a4675
...@@ -1341,11 +1341,12 @@ class TestSimpleCopyPaste: ...@@ -1341,11 +1341,12 @@ class TestSimpleCopyPaste:
mocker.MagicMock(spec=features.SegmentationMask), mocker.MagicMock(spec=features.SegmentationMask),
] ]
with pytest.raises(TypeError, match="requires input sample to contain equal-sized list of Images"): with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"):
transform._extract_image_targets(flat_sample) transform._extract_image_targets(flat_sample)
@pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor]) @pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor])
def test__extract_image_targets(self, image_type, mocker): @pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel])
def test__extract_image_targets(self, image_type, label_type, mocker):
transform = transforms.SimpleCopyPaste() transform = transforms.SimpleCopyPaste()
flat_sample = [ flat_sample = [
...@@ -1353,11 +1354,11 @@ class TestSimpleCopyPaste: ...@@ -1353,11 +1354,11 @@ class TestSimpleCopyPaste:
self.create_fake_image(mocker, image_type), self.create_fake_image(mocker, image_type),
self.create_fake_image(mocker, image_type), self.create_fake_image(mocker, image_type),
# labels, bboxes, masks # labels, bboxes, masks
mocker.MagicMock(spec=features.Label), mocker.MagicMock(spec=label_type),
mocker.MagicMock(spec=features.BoundingBox), mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask), mocker.MagicMock(spec=features.SegmentationMask),
# labels, bboxes, masks # labels, bboxes, masks
mocker.MagicMock(spec=features.Label), mocker.MagicMock(spec=label_type),
mocker.MagicMock(spec=features.BoundingBox), mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask), mocker.MagicMock(spec=features.SegmentationMask),
] ]
...@@ -1372,29 +1373,46 @@ class TestSimpleCopyPaste: ...@@ -1372,29 +1373,46 @@ class TestSimpleCopyPaste:
assert images[0] == flat_sample[0] assert images[0] == flat_sample[0]
assert images[1] == flat_sample[1] assert images[1] == flat_sample[1]
def test__copy_paste(self): for target in targets:
for key, type_ in [
("boxes", features.BoundingBox),
("masks", features.SegmentationMask),
("labels", label_type),
]:
assert key in target
assert isinstance(target[key], type_)
assert target[key] in flat_sample
@pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel])
def test__copy_paste(self, label_type):
image = 2 * torch.ones(3, 32, 32) image = 2 * torch.ones(3, 32, 32)
masks = torch.zeros(2, 32, 32) masks = torch.zeros(2, 32, 32)
masks[0, 3:9, 2:8] = 1 masks[0, 3:9, 2:8] = 1
masks[1, 20:30, 20:30] = 1 masks[1, 20:30, 20:30] = 1
labels = torch.tensor([1, 2])
if label_type == features.OneHotLabel:
labels = torch.nn.functional.one_hot(labels, num_classes=5)
target = { target = {
"boxes": features.BoundingBox( "boxes": features.BoundingBox(
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32) torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32)
), ),
"masks": features.SegmentationMask(masks), "masks": features.SegmentationMask(masks),
"labels": features.Label(torch.tensor([1, 2])), "labels": label_type(labels),
} }
paste_image = 10 * torch.ones(3, 32, 32) paste_image = 10 * torch.ones(3, 32, 32)
paste_masks = torch.zeros(2, 32, 32) paste_masks = torch.zeros(2, 32, 32)
paste_masks[0, 13:19, 12:18] = 1 paste_masks[0, 13:19, 12:18] = 1
paste_masks[1, 15:19, 1:8] = 1 paste_masks[1, 15:19, 1:8] = 1
paste_labels = torch.tensor([3, 4])
if label_type == features.OneHotLabel:
paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
paste_target = { paste_target = {
"boxes": features.BoundingBox( "boxes": features.BoundingBox(
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32) torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32)
), ),
"masks": features.SegmentationMask(paste_masks), "masks": features.SegmentationMask(paste_masks),
"labels": features.Label(torch.tensor([3, 4])), "labels": label_type(paste_labels),
} }
transform = transforms.SimpleCopyPaste() transform = transforms.SimpleCopyPaste()
...@@ -1405,7 +1423,12 @@ class TestSimpleCopyPaste: ...@@ -1405,7 +1423,12 @@ class TestSimpleCopyPaste:
assert output_target["boxes"].shape == (4, 4) assert output_target["boxes"].shape == (4, 4)
torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"]) torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"])
torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"]) torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
torch.testing.assert_close(output_target["labels"], features.Label(torch.tensor([1, 2, 3, 4])))
expected_labels = torch.tensor([1, 2, 3, 4])
if label_type == features.OneHotLabel:
expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
torch.testing.assert_close(output_target["labels"], label_type(expected_labels))
assert output_target["masks"].shape == (4, 32, 32) assert output_target["masks"].shape == (4, 32, 32)
torch.testing.assert_close(output_target["masks"][:2, :], target["masks"]) torch.testing.assert_close(output_target["masks"][:2, :], target["masks"])
torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"]) torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"])
......
...@@ -288,7 +288,7 @@ class SimpleCopyPaste(_RandomApplyTransform): ...@@ -288,7 +288,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
if not (len(images) == len(bboxes) == len(masks) == len(labels)): if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError( raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal-sized list of Images, " f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
"BoundingBoxes, Segmentation Masks and Labels or OneHotLabels." "BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment