Unverified Commit a6f3f95a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add tests for Coco (#3416)


Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent ccb7f45a
......@@ -21,6 +21,7 @@ import pickle
from torchvision import datasets
import torch
import shutil
import json
try:
......@@ -839,5 +840,70 @@ class VOCDetectionTestCase(VOCSegmentationTestCase):
self.assertEqual(object, info["annotation"])
class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CocoDetection
FEATURE_TYPES = (PIL.Image.Image, list)
REQUIRED_PACKAGES = ("pycocotools",)
def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)
num_images = 3
num_annotations_per_image = 2
image_folder = tmpdir / "images"
files = datasets_utils.create_image_folder(
tmpdir, name="images", file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_images
)
file_names = [file.relative_to(image_folder) for file in files]
annotation_folder = tmpdir / "annotations"
os.makedirs(annotation_folder)
annotation_file, info = self._create_annotation_file(annotation_folder, file_names, num_annotations_per_image)
info["num_examples"] = num_images
return (str(image_folder), str(annotation_file)), info
def _create_annotation_file(self, root, file_names, num_annotations_per_image):
image_ids = [int(file_name.stem) for file_name in file_names]
images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)]
annotations, info = self._create_annotations(image_ids, num_annotations_per_image)
content = dict(images=images, annotations=annotations)
return self._create_json(root, "annotations.json", content), info
def _create_annotations(self, image_ids, num_annotations_per_image):
annotations = datasets_utils.combinations_grid(
image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image
)
for id, annotation in enumerate(annotations):
annotation["id"] = id
return annotations, dict()
def _create_json(self, root, name, content):
file = pathlib.Path(root) / name
with open(file, "w") as fh:
json.dump(content, fh)
return file
class CocoCaptionsTestCase(CocoDetectionTestCase):
DATASET_CLASS = datasets.CocoCaptions
def _create_annotations(self, image_ids, num_annotations_per_image):
captions = [str(idx) for idx in range(num_annotations_per_image)]
annotations = datasets_utils.combinations_grid(image_id=image_ids, caption=captions)
for id, annotation in enumerate(annotations):
annotation["id"] = id
return annotations, dict(captions=captions)
def test_captions(self):
with self.create_dataset() as (dataset, info):
_, captions = dataset[0]
self.assertEqual(tuple(captions), tuple(info["captions"]))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment