Unverified commit f7d9e75b, authored by Nicolas Hug, committed by GitHub
Browse files

Support encoded RLE format for COCO segmentations (#8387)

parent 26af015a
......@@ -782,32 +782,46 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
annotation_folder = tmpdir / self._ANNOTATIONS_FOLDER
os.makedirs(annotation_folder)
segmentation_kind = config.pop("segmentation_kind", "list")
info = self._create_annotation_file(
annotation_folder, self._ANNOTATIONS_FILE, file_names, num_annotations_per_image
annotation_folder,
self._ANNOTATIONS_FILE,
file_names,
num_annotations_per_image,
segmentation_kind=segmentation_kind,
)
info["num_examples"] = num_images
return info
def _create_annotation_file(self, root, name, file_names, num_annotations_per_image):
def _create_annotation_file(self, root, name, file_names, num_annotations_per_image, segmentation_kind="list"):
image_ids = [int(file_name.stem) for file_name in file_names]
images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)]
annotations, info = self._create_annotations(image_ids, num_annotations_per_image)
annotations, info = self._create_annotations(image_ids, num_annotations_per_image, segmentation_kind)
self._create_json(root, name, dict(images=images, annotations=annotations))
return info
def _create_annotations(self, image_ids, num_annotations_per_image):
def _create_annotations(self, image_ids, num_annotations_per_image, segmentation_kind="list"):
annotations = []
annotion_id = 0
for image_id in itertools.islice(itertools.cycle(image_ids), len(image_ids) * num_annotations_per_image):
segmentation = {
"list": [torch.rand(8).tolist()],
"rle": {"size": [10, 10], "counts": [1]},
"rle_encoded": {"size": [2400, 2400], "counts": "PQRQ2[1\\Y2f0gNVNRhMg2"},
"bad": 123,
}[segmentation_kind]
annotations.append(
dict(
image_id=image_id,
id=annotion_id,
bbox=torch.rand(4).tolist(),
segmentation=[torch.rand(8).tolist()],
segmentation=segmentation,
category_id=int(torch.randint(91, ())),
area=float(torch.rand(1)),
iscrowd=int(torch.randint(2, size=(1,))),
......@@ -832,11 +846,27 @@ class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase):
with pytest.raises(ValueError, match="Index must be of type integer"):
dataset[:2]
def test_segmentation_kind(self):
    """Iterate the wrapped dataset for each fake segmentation kind.

    The kinds ``"list"``, ``"rle"`` and ``"rle_encoded"`` must iterate without
    error; the ``"bad"`` kind must raise ``ValueError``.
    """
    # The captions test case inherits this test; it is a no-op there.
    if isinstance(self, CocoCaptionsTestCase):
        return

    supported_kinds = ("list", "rle", "rle_encoded")
    for kind in supported_kinds:
        with self.create_dataset({"segmentation_kind": kind}) as (raw_dataset, _):
            wrapped = datasets.wrap_dataset_for_transforms_v2(raw_dataset, target_keys="all")
            # Consuming every sample exercises the segmentation conversion.
            list(wrapped)

    expected_message = "COCO segmentation expected to be a dict or a list"
    with self.create_dataset({"segmentation_kind": "bad"}) as (raw_dataset, _):
        wrapped = datasets.wrap_dataset_for_transforms_v2(raw_dataset, target_keys="all")
        with pytest.raises(ValueError, match=expected_message):
            list(wrapped)
class CocoCaptionsTestCase(CocoDetectionTestCase):
DATASET_CLASS = datasets.CocoCaptions
def _create_annotations(self, image_ids, num_annotations_per_image):
def _create_annotations(self, image_ids, num_annotations_per_image, segmentation_kind="list"):
captions = [str(idx) for idx in range(num_annotations_per_image)]
annotations = combinations_grid(image_id=image_ids, caption=captions)
for id, annotation in enumerate(annotations):
......
......@@ -359,11 +359,14 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
def segmentation_to_mask(segmentation, *, canvas_size):
    """Decode one COCO ``segmentation`` entry into a mask tensor.

    Args:
        segmentation: Either a dict with a ``"counts"`` entry, or a list that
            is converted with ``mask.frPyObjects`` and merged into one RLE.
        canvas_size: Sequence unpacked into ``mask.frPyObjects`` after the
            segmentation (presumably ``(height, width)`` -- confirm against
            the caller).

    Returns:
        ``torch.from_numpy(mask.decode(...))`` of the resulting RLE.

    Raises:
        ValueError: If ``segmentation`` is neither a dict nor a list.
    """
    # Validate first so unsupported annotations fail with a clear message
    # before any pycocotools call is attempted.
    if not isinstance(segmentation, (dict, list)):
        raise ValueError(f"COCO segmentation expected to be a dict or a list, got {type(segmentation)}")

    from pycocotools import mask

    if isinstance(segmentation, list):
        segmentation = mask.merge(mask.frPyObjects(segmentation, *canvas_size))
    elif not isinstance(segmentation["counts"], str):
        # If counts is a string, it is already an encoded RLE mask and is
        # decoded as-is; anything else is converted exactly once here.
        segmentation = mask.frPyObjects(segmentation, *canvas_size)

    return torch.from_numpy(mask.decode(segmentation))
def wrapper(idx, sample):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment