Unverified Commit 368d1c6f authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Updated tests for CopyPaste on OneHotLabel (#6485)

* [proto] Updated tests for CopyPaste on OneHotLabel

* Fixing test error
parent f82a4675
......@@ -1341,11 +1341,12 @@ class TestSimpleCopyPaste:
mocker.MagicMock(spec=features.SegmentationMask),
]
with pytest.raises(TypeError, match="requires input sample to contain equal-sized list of Images"):
with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"):
transform._extract_image_targets(flat_sample)
@pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor])
def test__extract_image_targets(self, image_type, mocker):
@pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel])
def test__extract_image_targets(self, image_type, label_type, mocker):
transform = transforms.SimpleCopyPaste()
flat_sample = [
......@@ -1353,11 +1354,11 @@ class TestSimpleCopyPaste:
self.create_fake_image(mocker, image_type),
self.create_fake_image(mocker, image_type),
# labels, bboxes, masks
mocker.MagicMock(spec=features.Label),
mocker.MagicMock(spec=label_type),
mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask),
# labels, bboxes, masks
mocker.MagicMock(spec=features.Label),
mocker.MagicMock(spec=label_type),
mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask),
]
......@@ -1372,29 +1373,46 @@ class TestSimpleCopyPaste:
assert images[0] == flat_sample[0]
assert images[1] == flat_sample[1]
def test__copy_paste(self):
for target in targets:
for key, type_ in [
("boxes", features.BoundingBox),
("masks", features.SegmentationMask),
("labels", label_type),
]:
assert key in target
assert isinstance(target[key], type_)
assert target[key] in flat_sample
@pytest.mark.parametrize("label_type", [features.Label, features.OneHotLabel])
def test__copy_paste(self, label_type):
image = 2 * torch.ones(3, 32, 32)
masks = torch.zeros(2, 32, 32)
masks[0, 3:9, 2:8] = 1
masks[1, 20:30, 20:30] = 1
labels = torch.tensor([1, 2])
if label_type == features.OneHotLabel:
labels = torch.nn.functional.one_hot(labels, num_classes=5)
target = {
"boxes": features.BoundingBox(
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32)
),
"masks": features.SegmentationMask(masks),
"labels": features.Label(torch.tensor([1, 2])),
"labels": label_type(labels),
}
paste_image = 10 * torch.ones(3, 32, 32)
paste_masks = torch.zeros(2, 32, 32)
paste_masks[0, 13:19, 12:18] = 1
paste_masks[1, 15:19, 1:8] = 1
paste_labels = torch.tensor([3, 4])
if label_type == features.OneHotLabel:
paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
paste_target = {
"boxes": features.BoundingBox(
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32)
),
"masks": features.SegmentationMask(paste_masks),
"labels": features.Label(torch.tensor([3, 4])),
"labels": label_type(paste_labels),
}
transform = transforms.SimpleCopyPaste()
......@@ -1405,7 +1423,12 @@ class TestSimpleCopyPaste:
assert output_target["boxes"].shape == (4, 4)
torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"])
torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
torch.testing.assert_close(output_target["labels"], features.Label(torch.tensor([1, 2, 3, 4])))
expected_labels = torch.tensor([1, 2, 3, 4])
if label_type == features.OneHotLabel:
expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
torch.testing.assert_close(output_target["labels"], label_type(expected_labels))
assert output_target["masks"].shape == (4, 32, 32)
torch.testing.assert_close(output_target["masks"][:2, :], target["masks"])
torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"])
......
......@@ -288,7 +288,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal-sized list of Images, "
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
"BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment