"vscode:/vscode.git/clone" did not exist on "cd135317d24cc8f48a883b6f9c025fc60870a13b"
Commit e0ad9ff2 authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 387878897
parent 44cad43a
......@@ -15,11 +15,10 @@
# Lint as: python3
"""Mask R-CNN configuration definition."""
import dataclasses
import os
from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
......@@ -79,6 +78,8 @@ class DataConfig(cfg.DataConfig):
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
drop_remainder: bool = True
# Number of examples in the data set, it's used to create the annotation file.
num_examples: int = -1
@dataclasses.dataclass
......
......@@ -18,6 +18,7 @@ import copy
import json
# Import libraries
from absl import logging
import numpy as np
from PIL import Image
......@@ -26,6 +27,7 @@ from pycocotools import mask as mask_api
import six
import tensorflow as tf
from official.common import dataset_fn
from official.vision.beta.dataloaders import tf_example_decoder
from official.vision.beta.ops import box_ops
from official.vision.beta.ops import mask_ops
......@@ -240,10 +242,20 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
(boxes[j, k, 3] - boxes[j, k, 1]) *
(boxes[j, k, 2] - boxes[j, k, 0]))
if 'masks' in groundtruths:
mask = Image.open(six.BytesIO(groundtruths['masks'][i][j, k]))
if isinstance(groundtruths['masks'][i][j, k], tf.Tensor):
mask = Image.open(
six.BytesIO(groundtruths['masks'][i][j, k].numpy()))
width, height = mask.size
np_mask = (
np.array(mask.getdata()).reshape(height,
width).astype(np.uint8))
else:
mask = Image.open(
six.BytesIO(groundtruths['masks'][i][j, k]))
width, height = mask.size
np_mask = (
np.array(mask.getdata()).reshape(height, width).astype(np.uint8))
np.array(mask.getdata()).reshape(height,
width).astype(np.uint8))
np_mask[np_mask > 0] = 255
encoded_mask = mask_api.encode(np.asfortranarray(np_mask))
ann['segmentation'] = encoded_mask
......@@ -271,11 +283,11 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
class COCOGroundtruthGenerator:
"""Generates the groundtruth annotations from a single example."""
def __init__(self, file_pattern, num_examples, include_mask):
def __init__(self, file_pattern, file_type, num_examples, include_mask):
self._file_pattern = file_pattern
self._num_examples = num_examples
self._include_mask = include_mask
self._dataset_fn = tf.data.TFRecordDataset
self._dataset_fn = dataset_fn.pick_dataset_fn(file_type)
def _parse_single_example(self, example):
"""Parses a single serialized tf.Example proto.
......@@ -308,7 +320,7 @@ class COCOGroundtruthGenerator:
boxes = box_ops.denormalize_boxes(
decoded_tensors['groundtruth_boxes'], image_size)
groundtruths = {
'source_id': tf.string_to_number(
'source_id': tf.strings.to_number(
decoded_tensors['source_id'], out_type=tf.int64),
'height': decoded_tensors['height'],
'width': decoded_tensors['width'],
......@@ -344,12 +356,13 @@ class COCOGroundtruthGenerator:
def scan_and_generator_annotation_file(file_pattern: str,
file_type: str,
num_samples: int,
include_mask: bool,
annotation_file: str):
"""Scans and generate the COCO-style annotation JSON file given a dataset."""
groundtruth_generator = COCOGroundtruthGenerator(
file_pattern, num_samples, include_mask)
file_pattern, file_type, num_samples, include_mask)
generate_annotation_file(groundtruth_generator, annotation_file)
......
......@@ -13,6 +13,7 @@
# limitations under the License.
"""RetinaNet task definition."""
import os
from typing import Any, Optional, List, Tuple, Mapping
from absl import logging
......@@ -26,6 +27,7 @@ from official.vision.beta.dataloaders import maskrcnn_input
from official.vision.beta.dataloaders import tf_example_decoder
from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.evaluation import coco_utils
from official.vision.beta.losses import maskrcnn_losses
from official.vision.beta.modeling import factory
......@@ -259,10 +261,33 @@ class MaskRCNNTask(base_task.Task):
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
else:
if self._task_config.annotation_file:
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=self._task_config.annotation_file,
include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
else:
annotation_path = os.path.join(self._logging_dir, 'annotation.json')
if tf.io.gfile.exists(annotation_path):
logging.info(
'annotation.json file exists, skipping creating the annotation'
' file.')
else:
if self._task_config.validation_data.num_examples <= 0:
logging.info('validation_data.num_examples needs to be > 0')
if not self._task_config.validation_data.input_path:
logging.info('Can not create annotation file for tfds.')
logging.info(
'Creating coco-style annotation file: %s', annotation_path)
coco_utils.scan_and_generator_annotation_file(
self._task_config.validation_data.input_path,
self._task_config.validation_data.file_type,
self._task_config.validation_data.num_examples,
self.task_config.model.include_mask, annotation_path)
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=annotation_path,
include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
return metrics
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment