"torchvision/vscode:/vscode.git/clone" did not exist on "372f4faeeb89d644705891be411499fc750d571c"
Commit 7a5be20b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Implement thresholding instance masks with a configurable value.

PiperOrigin-RevId: 351890707
parent 374cff58
...@@ -31,11 +31,13 @@ from official.vision.beta.configs import decoders ...@@ -31,11 +31,13 @@ from official.vision.beta.configs import decoders
@dataclasses.dataclass @dataclasses.dataclass
class TfExampleDecoder(hyperparams.Config): class TfExampleDecoder(hyperparams.Config):
regenerate_source_id: bool = False regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
@dataclasses.dataclass @dataclasses.dataclass
class TfExampleDecoderLabelMap(hyperparams.Config): class TfExampleDecoderLabelMap(hyperparams.Config):
regenerate_source_id: bool = False regenerate_source_id: bool = False
mask_binarize_threshold: Optional[float] = None
label_map: str = '' label_map: str = ''
......
...@@ -34,7 +34,8 @@ class TfExampleDecoder(decoder.Decoder): ...@@ -34,7 +34,8 @@ class TfExampleDecoder(decoder.Decoder):
def __init__(self, def __init__(self,
include_mask=False, include_mask=False,
regenerate_source_id=False): regenerate_source_id=False,
mask_binarize_threshold=None):
self._include_mask = include_mask self._include_mask = include_mask
self._regenerate_source_id = regenerate_source_id self._regenerate_source_id = regenerate_source_id
self._keys_to_features = { self._keys_to_features = {
...@@ -50,6 +51,7 @@ class TfExampleDecoder(decoder.Decoder): ...@@ -50,6 +51,7 @@ class TfExampleDecoder(decoder.Decoder):
'image/object/area': tf.io.VarLenFeature(tf.float32), 'image/object/area': tf.io.VarLenFeature(tf.float32),
'image/object/is_crowd': tf.io.VarLenFeature(tf.int64), 'image/object/is_crowd': tf.io.VarLenFeature(tf.int64),
} }
self._mask_binarize_threshold = mask_binarize_threshold
if include_mask: if include_mask:
self._keys_to_features.update({ self._keys_to_features.update({
'image/object/mask': tf.io.VarLenFeature(tf.string), 'image/object/mask': tf.io.VarLenFeature(tf.string),
...@@ -151,6 +153,9 @@ class TfExampleDecoder(decoder.Decoder): ...@@ -151,6 +153,9 @@ class TfExampleDecoder(decoder.Decoder):
if self._include_mask: if self._include_mask:
masks = self._decode_masks(parsed_tensors) masks = self._decode_masks(parsed_tensors)
if self._mask_binarize_threshold is not None:
masks = tf.cast(masks > self._mask_binarize_threshold, tf.float32)
decoded_tensors = { decoded_tensors = {
'source_id': source_id, 'source_id': source_id,
'image': image, 'image': image,
......
...@@ -27,9 +27,11 @@ from official.vision.beta.dataloaders import tf_example_decoder ...@@ -27,9 +27,11 @@ from official.vision.beta.dataloaders import tf_example_decoder
class TfExampleDecoderLabelMap(tf_example_decoder.TfExampleDecoder): class TfExampleDecoderLabelMap(tf_example_decoder.TfExampleDecoder):
"""Tensorflow Example proto decoder.""" """Tensorflow Example proto decoder."""
def __init__(self, label_map, include_mask=False, regenerate_source_id=False): def __init__(self, label_map, include_mask=False, regenerate_source_id=False,
mask_binarize_threshold=None):
super(TfExampleDecoderLabelMap, self).__init__( super(TfExampleDecoderLabelMap, self).__init__(
include_mask=include_mask, regenerate_source_id=regenerate_source_id) include_mask=include_mask, regenerate_source_id=regenerate_source_id,
mask_binarize_threshold=mask_binarize_threshold)
self._keys_to_features.update({ self._keys_to_features.update({
'image/object/class/text': tf.io.VarLenFeature(tf.string), 'image/object/class/text': tf.io.VarLenFeature(tf.string),
}) })
......
...@@ -110,12 +110,14 @@ class MaskRCNNTask(base_task.Task): ...@@ -110,12 +110,14 @@ class MaskRCNNTask(base_task.Task):
if params.decoder.type == 'simple_decoder': if params.decoder.type == 'simple_decoder':
decoder = tf_example_decoder.TfExampleDecoder( decoder = tf_example_decoder.TfExampleDecoder(
include_mask=self._task_config.model.include_mask, include_mask=self._task_config.model.include_mask,
regenerate_source_id=decoder_cfg.regenerate_source_id) regenerate_source_id=decoder_cfg.regenerate_source_id,
mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
elif params.decoder.type == 'label_map_decoder': elif params.decoder.type == 'label_map_decoder':
decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap( decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
label_map=decoder_cfg.label_map, label_map=decoder_cfg.label_map,
include_mask=self._task_config.model.include_mask, include_mask=self._task_config.model.include_mask,
regenerate_source_id=decoder_cfg.regenerate_source_id) regenerate_source_id=decoder_cfg.regenerate_source_id,
mask_binarize_threshold=decoder_cfg.mask_binarize_threshold)
else: else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type)) raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment