Commit cb41f780 authored by Sam Tsai's avatar Sam Tsai Committed by Facebook GitHub Bot
Browse files

refactored extended coco

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/179

Refactored extended coco to fix lint errors and to simplify error reporting.

Differential Revision: D34365252

fbshipit-source-id: 8bf221eba5b8c5e63ddcf5ca19d7486726aff797
parent d8bdc633
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import json import json
import logging import logging
import shlex import shlex
...@@ -141,20 +141,139 @@ def valid_bbox(bbox_xywh: List[int], img_w: int, img_h: int) -> bool: ...@@ -141,20 +141,139 @@ def valid_bbox(bbox_xywh: List[int], img_w: int, img_h: int) -> bool:
return True return True
def convert_coco_annotations(
    anno_dict_list: List[Dict], record: Dict, remapped_id: Dict, error_report: Dict
):
    """
    Converts a list of COCO-format annotation dicts (all belonging to the
    single image described by ``record``) into the internal format, while
    filtering out annotations with unmapped categories, invalid bounding
    boxes, or invalid segmentations.

    Args:
        anno_dict_list: COCO annotation dicts; each must have the same
            "image_id" as ``record``.
        record: per-image record; "image_id" is used for the consistency
            assert, and "width"/"height" (when present and truthy) are used
            to validate bounding boxes via ``valid_bbox``.
        remapped_id: maps original category ids to remapped ids; annotations
            whose "category_id" is not a key here are dropped.
        error_report: maps error names ("without_bbox_mode",
            "without_valid_bounding_box", "without_valid_segmentation") to
            counter objects whose ``cnt`` attribute is incremented in place
            as issues are found.

    Returns:
        List of converted annotation dicts (possibly empty).
    """
    converted_annotations = []
    for anno in anno_dict_list:
        # Check that the image_id in this annotation is the same. This fails
        # only when the data parsing logic or the annotation file is buggy.
        assert anno["image_id"] == record["image_id"]
        assert anno.get("ignore", 0) == 0

        # Copy fields that do not need additional conversion
        fields_to_copy = [
            "iscrowd",
            "bbox",
            "bbox_mode",
            "keypoints",
            "category_id",
            "extras",
            "point_coords",
            "point_labels",
        ]
        # NOTE: maybe use MetadataCatalog for this
        obj = {field: anno[field] for field in fields_to_copy if field in anno}

        # Filter out bad annotations where category does not match
        if obj.get("category_id", None) not in remapped_id:
            continue

        # Bounding boxes: convert and filter out bad bounding box annotations
        # NOTE(review): a bbox of all zeros is falsy and skips this branch
        # entirely — confirm that is intended.
        bbox_object = obj.get("bbox", None)
        if bbox_object:
            if "bbox_mode" in obj:
                bbox_object = BoxMode.convert(
                    bbox_object, obj["bbox_mode"], BoxMode.XYWH_ABS
                )
            else:
                # Assume default box mode is always (x, y, w, h); a 5-element
                # box is treated as rotated (XYWHA_ABS).
                # NOTE(review): in this branch bbox_object is NOT converted to
                # XYWH_ABS before the valid_bbox check below — verify that
                # valid_bbox handles the raw (possibly 5-element) box.
                error_report["without_bbox_mode"].cnt += 1
                obj["bbox_mode"] = (
                    BoxMode.XYWHA_ABS if len(obj["bbox"]) == 5 else BoxMode.XYWH_ABS
                )
            # Only validate when the record carries non-zero width/height.
            if (
                record.get("width")
                and record.get("height")
                and not valid_bbox(bbox_object, record["width"], record["height"])
            ):
                error_report["without_valid_bounding_box"].cnt += 1
                continue

        # Segmentation: filter and add segmentation
        segm = anno.get("segmentation", None)
        if segm:  # either list[list[float]] or dict(RLE)
            if not isinstance(segm, dict):
                # filter out invalid polygons (< 3 points)
                segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
                if len(segm) == 0:
                    error_report["without_valid_segmentation"].cnt += 1
                    continue  # ignore this instance
            obj["segmentation"] = segm

        # Remap ids
        obj["category_id"] = remapped_id[obj["category_id"]]
        converted_annotations.append(obj)
    return converted_annotations
# Error entry class for reporting coco conversion issues
class ErrorEntry:
    """A single category of conversion issue, with an occurrence counter.

    Attributes:
        error_name: machine-readable name of the issue.
        msg: human-readable message prefix used in the report.
        cnt: number of occurrences observed so far (mutated by callers).
    """

    def __init__(self, error_name, msg, cnt=0):
        self.error_name = error_name
        self.msg = msg
        self.cnt = cnt

    def __repr__(self):
        return "{} for {}, count = {}".format(self.msg, self.error_name, self.cnt)
def print_conversion_report(ann_error_report, image_error_report, ex_warning_fn):
    """Log a warning summarizing all conversion issues with nonzero counts.

    Args:
        ann_error_report: name -> ErrorEntry for annotation-level issues.
        image_error_report: name -> ErrorEntry for image-level issues.
        ex_warning_fn: example file name shown alongside the
            "ignore_image_root" issue (may be None).
    """
    lines = []
    # Image-level issues first; attach an example file name for the
    # ignored-image-root case when one was recorded.
    for name, entry in image_error_report.items():
        if entry.cnt > 0:
            lines.append(f"\t{entry}\n")
            if name == "ignore_image_root" and ex_warning_fn:
                lines.append(f"\texample file name {ex_warning_fn}\n")
    # Then annotation-level issues.
    for entry in ann_error_report.values():
        if entry.cnt > 0:
            lines.append(f"\t{entry}\n")
    report_str = "".join(lines)
    if report_str:
        logger.warning(f"Conversion issues:\n{report_str}")
def convert_to_dict_list( def convert_to_dict_list(
image_root: str, image_root: str,
id_map: Dict, remapped_id: Dict,
imgs: Dict, imgs: List[Dict],
anns: Dict, anns: List[Dict],
dataset_name: Optional[str] = None, dataset_name: Optional[str] = None,
image_direct_copy_keys: List[str] = None, image_direct_copy_keys: Optional[List[str]] = None,
) -> List[Dict]: ) -> List[Dict]:
num_instances_without_valid_segmentation = 0
num_instances_without_valid_bounding_box = 0 ann_error_report = {
dataset_dicts = [] name: ErrorEntry(name, msg, 0)
count_ignore_image_root_warning = 0 for name, msg in [
("without_valid_segmentation", "Instance filtered"),
("without_valid_bounding_box", "Instance filtered"),
("without_bbox_mode", "Warning"),
]
}
image_error_report = {
name: ErrorEntry(name, msg, 0)
for name, msg in [
("ignore_image_root", f"Image root ignored {image_root}"),
("no_annotations", "Image filtered"),
]
}
ex_warning_fn = None
default_record = {"dataset_name": dataset_name} if dataset_name else {}
converted_dict_list = []
for (img_dict, anno_dict_list) in zip(imgs, anns): for (img_dict, anno_dict_list) in zip(imgs, anns):
record = {} record = copy.deepcopy(default_record)
# NOTE: besides using (relative path) in the "file_name" filed to represent # NOTE: besides using (relative path) in the "file_name" filed to represent
# the image resource, "extended coco" also supports using uri which # the image resource, "extended coco" also supports using uri which
# represents an image using a single string, eg. "everstore_handle://xxx", # represents an image using a single string, eg. "everstore_handle://xxx",
...@@ -162,125 +281,45 @@ def convert_to_dict_list( ...@@ -162,125 +281,45 @@ def convert_to_dict_list(
record["file_name"] = os.path.join(image_root, img_dict["file_name"]) record["file_name"] = os.path.join(image_root, img_dict["file_name"])
else: else:
if image_root is not None: if image_root is not None:
count_ignore_image_root_warning += 1 image_error_report["ignore_image_root"].cnt += 1
if count_ignore_image_root_warning == 1: ex_warning_fn = (
logger.warning( ex_warning_fn if ex_warning_fn else img_dict["file_name"]
( )
"Found '://' in file_name: {}, ignore image_root: {}"
"(logged once per dataset)."
).format(img_dict["file_name"], image_root)
)
record["file_name"] = img_dict["file_name"] record["file_name"] = img_dict["file_name"]
if image_direct_copy_keys: # Setup image info and id
for copy_key in image_direct_copy_keys:
assert (
copy_key in img_dict
), f"{copy_key} not in coco image dictionary entry"
record[copy_key] = img_dict[copy_key]
if "height" in img_dict or "width" in img_dict: if "height" in img_dict or "width" in img_dict:
record["height"] = img_dict["height"] record["height"] = img_dict["height"]
record["width"] = img_dict["width"] record["width"] = img_dict["width"]
image_id = record["image_id"] = img_dict["id"] record["image_id"] = img_dict["id"]
objs = []
for anno in anno_dict_list:
# Check that the image_id in this annotation is the same. This fails
# only when the data parsing logic or the annotation file is buggy.
assert anno["image_id"] == image_id
assert anno.get("ignore", 0) == 0
obj = {
field: anno[field]
# NOTE: maybe use MetadataCatalog for this
for field in [
"iscrowd",
"bbox",
"bbox_mode",
"keypoints",
"category_id",
"extras",
"point_coords",
"point_labels",
]
if field in anno
}
bbox_object = obj.get("bbox", None)
if bbox_object is not None and "bbox_mode" in obj:
bbox_object = BoxMode.convert(
bbox_object, obj["bbox_mode"], BoxMode.XYWH_ABS
)
if (
record.get("width")
and record.get("height")
and not valid_bbox(bbox_object, record["width"], record["height"])
):
num_instances_without_valid_bounding_box += 1
continue
if obj.get("category_id", None) not in id_map:
continue
segm = anno.get("segmentation", None)
if segm: # either list[list[float]] or dict(RLE)
if not isinstance(segm, dict):
# filter out invalid polygons (< 3 points)
segm = [
poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6
]
if len(segm) == 0:
num_instances_without_valid_segmentation += 1
continue # ignore this instance
obj["segmentation"] = segm
if "bbox_mode" not in obj:
if len(obj["bbox"]) == 5:
obj["bbox_mode"] = BoxMode.XYWHA_ABS
else:
obj["bbox_mode"] = BoxMode.XYWH_ABS
if id_map:
obj["category_id"] = id_map[obj["category_id"]]
objs.append(obj)
record["annotations"] = objs
if len(objs) == 0:
continue
if dataset_name is not None:
record["dataset_name"] = dataset_name
dataset_dicts.append(record)
if count_ignore_image_root_warning > 0:
logger.warning(
"The 'ignore image_root: {}' warning occurred {} times".format(
image_root, count_ignore_image_root_warning
)
)
if num_instances_without_valid_segmentation > 0: # Convert annotation for dataset_dict
logger.warning( converted_anns = convert_coco_annotations(
"Filtered out {} instances without valid segmentation. " anno_dict_list, record, remapped_id, ann_error_report
"There might be issues in your dataset generation process.".format(
num_instances_without_valid_segmentation
)
) )
if len(converted_anns) == 0:
image_error_report["no_annotations"].cnt += 1
continue
record["annotations"] = converted_anns
if num_instances_without_valid_bounding_box > 0: # Copy keys if additionally asked
logger.warning( if image_direct_copy_keys:
"Filtered out {} instances without valid bounding boxes. " for c_key in image_direct_copy_keys:
"There might be issues in your dataset generation process.".format( assert c_key in img_dict, f"{c_key} not in coco image entry annotation"
num_instances_without_valid_bounding_box record[c_key] = img_dict[c_key]
)
) converted_dict_list.append(record)
assert len(dataset_dicts) != 0, ( print_conversion_report(ann_error_report, image_error_report, ex_warning_fn)
f"Loaded zero entries from {dataset_name} empty. \n"
assert len(converted_dict_list) != 0, (
f"Loaded zero entries from {dataset_name}. \n"
f" Size of inputs (imgs={len(imgs)}, anns={len(anns)})\n" f" Size of inputs (imgs={len(imgs)}, anns={len(anns)})\n"
f" Filtered of inputs (seg={num_instances_without_valid_segmentation}," f" Image issues ({image_error_report})\n"
f" ={num_instances_without_valid_bounding_box}\n" f" Instance issues ({ann_error_report})\n"
) )
return dataset_dicts return converted_dict_list
def coco_text_load( def coco_text_load(
...@@ -344,33 +383,29 @@ def extended_coco_load( ...@@ -344,33 +383,29 @@ def extended_coco_load(
else: else:
coco_api = InMemoryCOCO(loaded_json) coco_api = InMemoryCOCO(loaded_json)
id_map = None # Collect classes and remap them starting from 0
# Get filtered classes
all_cat_ids = coco_api.getCatIds() all_cat_ids = coco_api.getCatIds()
all_cats = coco_api.loadCats(all_cat_ids) all_cats = coco_api.loadCats(all_cat_ids)
all_cat_names = [c["name"] for c in sorted(all_cats, key=lambda x: x["id"])]
# Setup classes to use for creating id map # Setup id remapping
classes_to_use = [c["name"] for c in sorted(all_cats, key=lambda x: x["id"])] remapped_id = {}
# Setup id map
id_map = {}
for cat_id, cat in zip(all_cat_ids, all_cats): for cat_id, cat in zip(all_cat_ids, all_cats):
if cat["name"] in classes_to_use: remapped_id[cat_id] = all_cat_names.index(cat["name"])
id_map[cat_id] = classes_to_use.index(cat["name"])
# Register dataset in metadata catalog # Register dataset in metadata catalog
if dataset_name is not None: if dataset_name is not None:
# overwrite attrs # overwrite attrs
meta_dict = MetadataCatalog.get(dataset_name).as_dict() meta_dict = MetadataCatalog.get(dataset_name).as_dict()
meta_dict["thing_classes"] = classes_to_use meta_dict["thing_classes"] = all_cat_names
meta_dict["thing_dataset_id_to_contiguous_id"] = id_map meta_dict["thing_dataset_id_to_contiguous_id"] = remapped_id
# update MetadataCatalog (cannot change inplace, has to remove) # update MetadataCatalog (cannot change inplace, have to remove)
MetadataCatalog.remove(dataset_name) MetadataCatalog.remove(dataset_name)
MetadataCatalog.get(dataset_name).set(**meta_dict) MetadataCatalog.get(dataset_name).set(**meta_dict)
# assert the change # assert the change
assert MetadataCatalog.get(dataset_name).thing_classes == classes_to_use assert MetadataCatalog.get(dataset_name).thing_classes == all_cat_names
# sort indices for reproducible results # Sort indices for reproducible results
img_ids = sorted(coco_api.imgs.keys()) img_ids = sorted(coco_api.imgs.keys())
imgs = coco_api.loadImgs(img_ids) imgs = coco_api.loadImgs(img_ids)
anns = [coco_api.imgToAnns[img_id] for img_id in img_ids] anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
...@@ -379,7 +414,7 @@ def extended_coco_load( ...@@ -379,7 +414,7 @@ def extended_coco_load(
# Return the coco converted to record list # Return the coco converted to record list
return convert_to_dict_list( return convert_to_dict_list(
image_root, image_root,
id_map, remapped_id,
imgs, imgs,
anns, anns,
dataset_name, dataset_name,
......
...@@ -149,6 +149,7 @@ class TestD2GoDatasets(unittest.TestCase): ...@@ -149,6 +149,7 @@ class TestD2GoDatasets(unittest.TestCase):
ann_list, ann_list,
) )
self.assertEqual(len(out_dict_list), 1) self.assertEqual(len(out_dict_list), 1)
self.assertEqual(len(out_dict_list[0]["annotations"]), 1)
@tempdir @tempdir
def test_coco_injection(self, tmp_dir): def test_coco_injection(self, tmp_dir):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment