Unverified Commit 9ffeeb5f authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add support for evaluating masks and keypoints for custom dataset (#938)

parent f4d43ccf
...@@ -151,10 +151,12 @@ def convert_to_coco_api(ds): ...@@ -151,10 +151,12 @@ def convert_to_coco_api(ds):
for img_idx in range(len(ds)): for img_idx in range(len(ds)):
# find better way to get target # find better way to get target
# targets = ds.get_annotations(img_idx) # targets = ds.get_annotations(img_idx)
_, targets = ds[img_idx] img, targets = ds[img_idx]
image_id = targets["image_id"].item() image_id = targets["image_id"].item()
img_dict = {} img_dict = {}
img_dict['id'] = image_id img_dict['id'] = image_id
img_dict['height'] = img.shape[-2]
img_dict['width'] = img.shape[-1]
dataset['images'].append(img_dict) dataset['images'].append(img_dict)
bboxes = targets["boxes"] bboxes = targets["boxes"]
bboxes[:, 2:] -= bboxes[:, :2] bboxes[:, 2:] -= bboxes[:, :2]
...@@ -162,7 +164,13 @@ def convert_to_coco_api(ds): ...@@ -162,7 +164,13 @@ def convert_to_coco_api(ds):
labels = targets['labels'].tolist() labels = targets['labels'].tolist()
areas = targets['area'].tolist() areas = targets['area'].tolist()
iscrowd = targets['iscrowd'].tolist() iscrowd = targets['iscrowd'].tolist()
# TODO need to add masks as well if 'masks' in targets:
masks = targets['masks']
# make masks Fortran contiguous for coco_mask
masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
if 'keypoints' in targets:
keypoints = targets['keypoints']
keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
num_objs = len(bboxes) num_objs = len(bboxes)
for i in range(num_objs): for i in range(num_objs):
ann = {} ann = {}
...@@ -173,6 +181,11 @@ def convert_to_coco_api(ds): ...@@ -173,6 +181,11 @@ def convert_to_coco_api(ds):
ann['area'] = areas[i] ann['area'] = areas[i]
ann['iscrowd'] = iscrowd[i] ann['iscrowd'] = iscrowd[i]
ann['id'] = ann_id ann['id'] = ann_id
if 'masks' in targets:
ann["segmentation"] = coco_mask.encode(masks[i].numpy())
if 'keypoints' in targets:
ann['keypoints'] = keypoints[i]
ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3])
dataset['annotations'].append(ann) dataset['annotations'].append(ann)
ann_id += 1 ann_id += 1
dataset['categories'] = [{'id': i} for i in sorted(categories)] dataset['categories'] = [{'id': i} for i in sorted(categories)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment