Commit e0ad9ff2 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 387878897
parent 44cad43a
...@@ -15,11 +15,10 @@ ...@@ -15,11 +15,10 @@
# Lint as: python3 # Lint as: python3
"""Mask R-CNN configuration definition.""" """Mask R-CNN configuration definition."""
import dataclasses
import os import os
from typing import List, Optional from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
...@@ -79,6 +78,8 @@ class DataConfig(cfg.DataConfig): ...@@ -79,6 +78,8 @@ class DataConfig(cfg.DataConfig):
shuffle_buffer_size: int = 10000 shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord' file_type: str = 'tfrecord'
drop_remainder: bool = True drop_remainder: bool = True
# Number of examples in the data set, it's used to create the annotation file.
num_examples: int = -1
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -18,6 +18,7 @@ import copy ...@@ -18,6 +18,7 @@ import copy
import json import json
# Import libraries # Import libraries
from absl import logging from absl import logging
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -26,6 +27,7 @@ from pycocotools import mask as mask_api ...@@ -26,6 +27,7 @@ from pycocotools import mask as mask_api
import six import six
import tensorflow as tf import tensorflow as tf
from official.common import dataset_fn
from official.vision.beta.dataloaders import tf_example_decoder from official.vision.beta.dataloaders import tf_example_decoder
from official.vision.beta.ops import box_ops from official.vision.beta.ops import box_ops
from official.vision.beta.ops import mask_ops from official.vision.beta.ops import mask_ops
...@@ -240,10 +242,20 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None): ...@@ -240,10 +242,20 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
(boxes[j, k, 3] - boxes[j, k, 1]) * (boxes[j, k, 3] - boxes[j, k, 1]) *
(boxes[j, k, 2] - boxes[j, k, 0])) (boxes[j, k, 2] - boxes[j, k, 0]))
if 'masks' in groundtruths: if 'masks' in groundtruths:
mask = Image.open(six.BytesIO(groundtruths['masks'][i][j, k])) if isinstance(groundtruths['masks'][i][j, k], tf.Tensor):
width, height = mask.size mask = Image.open(
np_mask = ( six.BytesIO(groundtruths['masks'][i][j, k].numpy()))
np.array(mask.getdata()).reshape(height, width).astype(np.uint8)) width, height = mask.size
np_mask = (
np.array(mask.getdata()).reshape(height,
width).astype(np.uint8))
else:
mask = Image.open(
six.BytesIO(groundtruths['masks'][i][j, k]))
width, height = mask.size
np_mask = (
np.array(mask.getdata()).reshape(height,
width).astype(np.uint8))
np_mask[np_mask > 0] = 255 np_mask[np_mask > 0] = 255
encoded_mask = mask_api.encode(np.asfortranarray(np_mask)) encoded_mask = mask_api.encode(np.asfortranarray(np_mask))
ann['segmentation'] = encoded_mask ann['segmentation'] = encoded_mask
...@@ -271,11 +283,11 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None): ...@@ -271,11 +283,11 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
class COCOGroundtruthGenerator: class COCOGroundtruthGenerator:
"""Generates the groundtruth annotations from a single example.""" """Generates the groundtruth annotations from a single example."""
def __init__(self, file_pattern, num_examples, include_mask): def __init__(self, file_pattern, file_type, num_examples, include_mask):
self._file_pattern = file_pattern self._file_pattern = file_pattern
self._num_examples = num_examples self._num_examples = num_examples
self._include_mask = include_mask self._include_mask = include_mask
self._dataset_fn = tf.data.TFRecordDataset self._dataset_fn = dataset_fn.pick_dataset_fn(file_type)
def _parse_single_example(self, example): def _parse_single_example(self, example):
"""Parses a single serialized tf.Example proto. """Parses a single serialized tf.Example proto.
...@@ -308,7 +320,7 @@ class COCOGroundtruthGenerator: ...@@ -308,7 +320,7 @@ class COCOGroundtruthGenerator:
boxes = box_ops.denormalize_boxes( boxes = box_ops.denormalize_boxes(
decoded_tensors['groundtruth_boxes'], image_size) decoded_tensors['groundtruth_boxes'], image_size)
groundtruths = { groundtruths = {
'source_id': tf.string_to_number( 'source_id': tf.strings.to_number(
decoded_tensors['source_id'], out_type=tf.int64), decoded_tensors['source_id'], out_type=tf.int64),
'height': decoded_tensors['height'], 'height': decoded_tensors['height'],
'width': decoded_tensors['width'], 'width': decoded_tensors['width'],
...@@ -344,12 +356,13 @@ class COCOGroundtruthGenerator: ...@@ -344,12 +356,13 @@ class COCOGroundtruthGenerator:
def scan_and_generator_annotation_file(file_pattern: str, def scan_and_generator_annotation_file(file_pattern: str,
file_type: str,
num_samples: int, num_samples: int,
include_mask: bool, include_mask: bool,
annotation_file: str): annotation_file: str):
"""Scans and generate the COCO-style annotation JSON file given a dataset.""" """Scans and generate the COCO-style annotation JSON file given a dataset."""
groundtruth_generator = COCOGroundtruthGenerator( groundtruth_generator = COCOGroundtruthGenerator(
file_pattern, num_samples, include_mask) file_pattern, file_type, num_samples, include_mask)
generate_annotation_file(groundtruth_generator, annotation_file) generate_annotation_file(groundtruth_generator, annotation_file)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""RetinaNet task definition.""" """RetinaNet task definition."""
import os
from typing import Any, Optional, List, Tuple, Mapping from typing import Any, Optional, List, Tuple, Mapping
from absl import logging from absl import logging
...@@ -26,6 +27,7 @@ from official.vision.beta.dataloaders import maskrcnn_input ...@@ -26,6 +27,7 @@ from official.vision.beta.dataloaders import maskrcnn_input
from official.vision.beta.dataloaders import tf_example_decoder from official.vision.beta.dataloaders import tf_example_decoder
from official.vision.beta.dataloaders import tf_example_label_map_decoder from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.evaluation import coco_evaluator from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.evaluation import coco_utils
from official.vision.beta.losses import maskrcnn_losses from official.vision.beta.losses import maskrcnn_losses
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
...@@ -259,10 +261,33 @@ class MaskRCNNTask(base_task.Task): ...@@ -259,10 +261,33 @@ class MaskRCNNTask(base_task.Task):
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32)) metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
else: else:
self.coco_metric = coco_evaluator.COCOEvaluator( if self._task_config.annotation_file:
annotation_file=self._task_config.annotation_file, self.coco_metric = coco_evaluator.COCOEvaluator(
include_mask=self._task_config.model.include_mask, annotation_file=self._task_config.annotation_file,
per_category_metrics=self._task_config.per_category_metrics) include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
else:
annotation_path = os.path.join(self._logging_dir, 'annotation.json')
if tf.io.gfile.exists(annotation_path):
logging.info(
'annotation.json file exists, skipping creating the annotation'
' file.')
else:
if self._task_config.validation_data.num_examples <= 0:
logging.info('validation_data.num_examples needs to be > 0')
if not self._task_config.validation_data.input_path:
logging.info('Can not create annotation file for tfds.')
logging.info(
'Creating coco-style annotation file: %s', annotation_path)
coco_utils.scan_and_generator_annotation_file(
self._task_config.validation_data.input_path,
self._task_config.validation_data.file_type,
self._task_config.validation_data.num_examples,
self.task_config.model.include_mask, annotation_path)
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=annotation_path,
include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
return metrics return metrics
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment