Unverified commit 01398088, authored by Philip Meier, committed by GitHub
Browse files

add tests for Flickr(8|30)k datasets (#3489)



* add tests for Flickr8k dataset

* add tests for Flickr30k dataset

* lint
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 8fe439e6
......@@ -1354,5 +1354,92 @@ class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase):
self.FEATURE_TYPES = feature_types
class Flickr8kTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.Flickr8k`` backed by generated fake data.

    The fake data consists of an image folder and an HTML annotations file,
    both created below the temporary directory handed to ``inject_fake_data``.
    """

    DATASET_CLASS = datasets.Flickr8k

    FEATURE_TYPES = (PIL.Image.Image, list)

    # Names (relative to the temporary directory) of the generated fixtures.
    _IMAGES_FOLDER = "images"
    _ANNOTATIONS_FILE = "captions.html"

    def dataset_args(self, tmpdir, config):
        """Return the image folder and annotations file paths below ``tmpdir`` as strings."""
        tmpdir = pathlib.Path(tmpdir)
        root = tmpdir / self._IMAGES_FOLDER
        ann_file = tmpdir / self._ANNOTATIONS_FILE
        return str(root), str(ann_file)

    def inject_fake_data(self, tmpdir, config):
        """Create the fake images and annotations file below ``tmpdir``.

        Returns:
            A dict with ``num_examples`` (the number of generated images) and
            ``captions`` (the captions written for every image).
        """
        num_images = 3
        num_captions_per_image = 3

        tmpdir = pathlib.Path(tmpdir)

        images = self._create_images(tmpdir, self._IMAGES_FOLDER, num_images)
        self._create_annotations_file(tmpdir, self._ANNOTATIONS_FILE, images, num_captions_per_image)

        return dict(num_examples=num_images, captions=self._create_captions(num_captions_per_image))

    def _create_images(self, root, name, num_images):
        """Create ``num_images`` images in the folder ``name`` below ``root``.

        Delegates to ``datasets_utils.create_image_folder`` and returns its result.
        NOTE(review): the result is iterated and each item's ``.name`` is read
        by the annotation writers, so it is presumably a sequence of paths.
        """
        return datasets_utils.create_image_folder(root, name, self._image_file_name, num_images)

    def _image_file_name(self, idx):
        """Return a random file name of the form ``<id>_<checksum>_<size>.jpg``.

        ``idx`` is not used; the name is built purely from random strings.
        """
        # Named ``image_id`` rather than ``id`` to avoid shadowing the builtin.
        image_id = datasets_utils.create_random_string(10, string.digits)
        checksum = datasets_utils.create_random_string(10, string.digits, string.ascii_lowercase[:6])
        size = datasets_utils.create_random_string(1, "qwcko")
        return f"{image_id}_{checksum}_{size}.jpg"

    def _create_annotations_file(self, root, name, images, num_captions_per_image):
        """Write the HTML annotations file ``root / name`` for ``images``.

        The table starts with one extra entry that has no image (written as
        "Image Not Found"), followed by one entry per image.
        """
        with open(root / name, "w") as fh:
            fh.write("<table>")
            # The leading ``None`` yields the entry without an image link.
            for image in (None, *images):
                self._add_image(fh, image, num_captions_per_image)
            fh.write("</table>")

    def _add_image(self, fh, image, num_captions_per_image):
        """Write one table entry: a header row followed by a row with the caption list."""
        fh.write("<tr>")
        self._add_image_header(fh, image)
        fh.write("</tr><tr><td><ul>")
        self._add_image_captions(fh, num_captions_per_image)
        fh.write("</ul></td></tr>")

    def _add_image_header(self, fh, image=None):
        """Write the header cell for ``image``.

        For an image, the cell holds a link whose URL ends with the part of the
        file name before the first underscore. Without an image, the cell holds
        the text "Image Not Found".
        """
        # ``None`` is the explicit "no image" sentinel, so test for it directly
        # instead of relying on truthiness.
        if image is not None:
            url = f"http://www.flickr.com/photos/user/{image.name.split('_')[0]}/"
            data = f'<a href="{url}">{url}</a>'
        else:
            data = "Image Not Found"
        fh.write(f"<td>{data}</td>")

    def _add_image_captions(self, fh, num_captions_per_image):
        """Write one ``<li>`` item per caption; the items are written without a closing tag."""
        for caption in self._create_captions(num_captions_per_image):
            fh.write(f"<li>{caption}")

    def _create_captions(self, num_captions_per_image):
        """Return the captions ``["0", "1", ...]`` with ``num_captions_per_image`` entries."""
        return [str(idx) for idx in range(num_captions_per_image)]

    def test_captions(self):
        """Check that the captions of the first sample equal the generated captions."""
        with self.create_dataset() as (dataset, info):
            _, captions = dataset[0]
            self.assertSequenceEqual(captions, info["captions"])
class Flickr30kTestCase(Flickr8kTestCase):
    """Test case for ``datasets.Flickr30k``.

    Reuses the fake-data machinery of ``Flickr8kTestCase`` and only overrides
    the image file naming and the annotations file format.
    """

    DATASET_CLASS = datasets.Flickr30k

    FEATURE_TYPES = (PIL.Image.Image, list)

    _ANNOTATIONS_FILE = "captions.token"

    def _image_file_name(self, idx):
        """Return the file name ``<idx>.jpg`` for the image with index ``idx``."""
        return f"{idx}.jpg"

    def _create_annotations_file(self, root, name, images, num_captions_per_image):
        """Write the plain-text annotations file ``root / name`` for ``images``.

        Every line has the form ``<image file name>#<caption index>\\t<caption>``;
        all captions of one image are written before moving on to the next image.
        """
        with open(root / name, "w") as fh:
            captions = self._create_captions(num_captions_per_image)
            for image in images:
                for idx, caption in enumerate(captions):
                    fh.write(f"{image.name}#{idx}\t{caption}\n")
# Run the test cases of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment