Unverified Commit c4ce3a9e authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

make keys to panoptic masks configurable in `TfExampleDecoder`

parent a9e15bd9
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -51,10 +51,16 @@ class Parser(hyperparams.Config):
dtype = 'float32'
@dataclasses.dataclass
class TfExampleDecoder(common.TfExampleDecoder):
"""A simple TF Example decoder config."""
panoptic_category_mask_key: str = 'image/panoptic/category_mask'
panoptic_instance_mask_key: str = 'image/panoptic/instance_mask'
@dataclasses.dataclass
class DataDecoder(common.DataDecoder):
"""Data decoder config."""
simple_decoder: common.TfExampleDecoder = common.TfExampleDecoder()
simple_decoder: TfExampleDecoder = TfExampleDecoder()
@dataclasses.dataclass
......@@ -164,7 +170,7 @@ class PanopticDeeplabTask(cfg.TaskConfig):
evaluation: Evaluation = Evaluation()
@exp_factory.register_config_factory('panoptic_deeplab_coco')
@exp_factory.register_config_factory('panoptic_deeplab_resnet_coco')
def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
"""COCO panoptic segmentation with Panoptic Deeplab."""
train_steps = 200000
......
......@@ -41,16 +41,22 @@ def _compute_gaussian_from_std(sigma):
class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
"""Tensorflow Example proto decoder."""
def __init__(self, regenerate_source_id):
def __init__(
self,
regenerate_source_id: bool,
panoptic_category_mask_key: str = 'image/panoptic/category_mask',
panoptic_instance_mask_key: str = 'image/panoptic/instance_mask'):
super(TfExampleDecoder,
self).__init__(
include_mask=True,
regenerate_source_id=regenerate_source_id)
self._panoptic_category_mask_key = panoptic_category_mask_key
self._panoptic_instance_mask_key = panoptic_instance_mask_key
self._panoptic_keys_to_features = {
'image/panoptic/category_mask':
panoptic_category_mask_key:
tf.io.FixedLenFeature((), tf.string, default_value=''),
'image/panoptic/instance_mask':
panoptic_instance_mask_key:
tf.io.FixedLenFeature((), tf.string, default_value='')
}
......@@ -61,9 +67,9 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
serialized_example, self._panoptic_keys_to_features)
category_mask = tf.io.decode_image(
parsed_tensors['image/panoptic/category_mask'], channels=1)
parsed_tensors[self._panoptic_category_mask_key], channels=1)
instance_mask = tf.io.decode_image(
parsed_tensors['image/panoptic/instance_mask'], channels=1)
parsed_tensors[self._panoptic_instance_mask_key], channels=1)
category_mask.set_shape([None, None, 1])
instance_mask.set_shape([None, None, 1])
......
......@@ -90,7 +90,9 @@ class PanopticDeeplabTask(base_task.Task):
if params.decoder.type == 'simple_decoder':
decoder = panoptic_deeplab_input.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id)
regenerate_source_id=decoder_cfg.regenerate_source_id,
panoptic_category_mask_key=decoder_cfg.panoptic_category_mask_key,
panoptic_instance_mask_key=decoder_cfg.panoptic_instance_mask_key)
else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment