Unverified Commit 06ad7376 authored by Andrew Lingg's avatar Andrew Lingg Committed by GitHub
Browse files

Added binary cat vs dog classification target type to Oxford pet dataset (#8388)


Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent ff31b6e3
......@@ -2553,7 +2553,7 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
ADDITIONAL_CONFIGS = combinations_grid(
split=("trainval", "test"),
target_types=("category", "segmentation", ["category", "segmentation"], []),
target_types=("category", "binary-category", "segmentation", ["category", "segmentation"], []),
)
def inject_fake_data(self, tmpdir, config):
......
......@@ -19,6 +19,7 @@ class OxfordIIITPet(VisionDataset):
``segmentation``. Can also be a list to output a tuple with all specified target types. The types represent:
- ``category`` (int): Label for one of the 37 pet categories.
- ``binary-category`` (int): Binary label for cat or dog.
- ``segmentation`` (PIL image): Segmentation trimap of the image.
If empty, ``None`` will be returned as target.
......@@ -34,7 +35,7 @@ class OxfordIIITPet(VisionDataset):
("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
)
_VALID_TARGET_TYPES = ("category", "segmentation")
_VALID_TARGET_TYPES = ("category", "binary-category", "segmentation")
def __init__(
self,
......@@ -67,12 +68,15 @@ class OxfordIIITPet(VisionDataset):
image_ids = []
self._labels = []
self._bin_labels = []
with open(self._anns_folder / f"{self._split}.txt") as file:
for line in file:
image_id, label, *_ = line.strip().split()
image_id, label, bin_label, _ = line.strip().split()
image_ids.append(image_id)
self._labels.append(int(label) - 1)
self._bin_labels.append(int(bin_label) - 1)
self.bin_classes = ["Cat", "Dog"]
self.classes = [
" ".join(part.title() for part in raw_cls.split("_"))
for raw_cls, _ in sorted(
......@@ -80,6 +84,7 @@ class OxfordIIITPet(VisionDataset):
key=lambda image_id_and_label: image_id_and_label[1],
)
]
self.bin_class_to_idx = dict(zip(self.bin_classes, range(len(self.bin_classes))))
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids]
......@@ -95,6 +100,8 @@ class OxfordIIITPet(VisionDataset):
for target_type in self._target_types:
if target_type == "category":
target.append(self._labels[idx])
elif target_type == "binary-category":
target.append(self._bin_labels[idx])
else: # target_type == "segmentation"
target.append(Image.open(self._segs[idx]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment