Commit cb93c205 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 336914823
parent cf084d8b
...@@ -145,6 +145,7 @@ class MaskHead(hyperparams.Config): ...@@ -145,6 +145,7 @@ class MaskHead(hyperparams.Config):
num_convs: int = 4 num_convs: int = 4
num_filters: int = 256 num_filters: int = 256
use_separable_conv: bool = False use_separable_conv: bool = False
class_agnostic: bool = False
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -156,7 +156,8 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec, ...@@ -156,7 +156,8 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
activation=model_config.norm_activation.activation, activation=model_config.norm_activation.activation,
norm_momentum=model_config.norm_activation.norm_momentum, norm_momentum=model_config.norm_activation.norm_momentum,
norm_epsilon=model_config.norm_activation.norm_epsilon, norm_epsilon=model_config.norm_activation.norm_epsilon,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer,
class_agnostic=model_config.mask_head.class_agnostic)
mask_sampler_obj = mask_sampler.MaskSampler( mask_sampler_obj = mask_sampler.MaskSampler(
mask_target_size=( mask_target_size=(
......
...@@ -225,6 +225,7 @@ class MaskHead(tf.keras.layers.Layer): ...@@ -225,6 +225,7 @@ class MaskHead(tf.keras.layers.Layer):
norm_epsilon=0.001, norm_epsilon=0.001,
kernel_regularizer=None, kernel_regularizer=None,
bias_regularizer=None, bias_regularizer=None,
class_agnostic=False,
**kwargs): **kwargs):
"""Initialize params to build the mask head. """Initialize params to build the mask head.
...@@ -248,6 +249,8 @@ class MaskHead(tf.keras.layers.Layer): ...@@ -248,6 +249,8 @@ class MaskHead(tf.keras.layers.Layer):
kernel_regularizer: `tf.keras.regularizers.Regularizer` object for layer kernel_regularizer: `tf.keras.regularizers.Regularizer` object for layer
kernel. kernel.
bias_regularizer: `tf.keras.regularizers.Regularizer` object for bias. bias_regularizer: `tf.keras.regularizers.Regularizer` object for bias.
class_agnostic: `bool`, if set, we use a single channel mask head that
is shared between all classes.
**kwargs: other keyword arguments passed to Layer. **kwargs: other keyword arguments passed to Layer.
""" """
super(MaskHead, self).__init__(**kwargs) super(MaskHead, self).__init__(**kwargs)
...@@ -263,6 +266,7 @@ class MaskHead(tf.keras.layers.Layer): ...@@ -263,6 +266,7 @@ class MaskHead(tf.keras.layers.Layer):
'norm_epsilon': norm_epsilon, 'norm_epsilon': norm_epsilon,
'kernel_regularizer': kernel_regularizer, 'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer, 'bias_regularizer': bias_regularizer,
'class_agnostic': class_agnostic
} }
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
...@@ -330,8 +334,13 @@ class MaskHead(tf.keras.layers.Layer): ...@@ -330,8 +334,13 @@ class MaskHead(tf.keras.layers.Layer):
name='mask-upsampling') name='mask-upsampling')
self._deconv_bn = bn_op(name='mask-deconv-bn', **bn_kwargs) self._deconv_bn = bn_op(name='mask-deconv-bn', **bn_kwargs)
if self._config_dict['class_agnostic']:
num_filters = 1
else:
num_filters = self._config_dict['num_classes']
conv_kwargs = { conv_kwargs = {
'filters': self._config_dict['num_classes'], 'filters': num_filters,
'kernel_size': 1, 'kernel_size': 1,
'padding': 'valid', 'padding': 'valid',
} }
...@@ -395,17 +404,27 @@ class MaskHead(tf.keras.layers.Layer): ...@@ -395,17 +404,27 @@ class MaskHead(tf.keras.layers.Layer):
mask_height = height * self._config_dict['upsample_factor'] mask_height = height * self._config_dict['upsample_factor']
mask_width = width * self._config_dict['upsample_factor'] mask_width = width * self._config_dict['upsample_factor']
logits = tf.reshape(
logits, if self._config_dict['class_agnostic']:
[-1, num_rois, mask_height, mask_width, logits = tf.reshape(logits, [-1, num_rois, mask_height, mask_width, 1])
self._config_dict['num_classes']]) else:
logits = tf.reshape(
logits,
[-1, num_rois, mask_height, mask_width,
self._config_dict['num_classes']])
batch_indices = tf.tile( batch_indices = tf.tile(
tf.expand_dims(tf.range(batch_size), axis=1), [1, num_rois]) tf.expand_dims(tf.range(batch_size), axis=1), [1, num_rois])
mask_indices = tf.tile( mask_indices = tf.tile(
tf.expand_dims(tf.range(num_rois), axis=0), [batch_size, 1]) tf.expand_dims(tf.range(num_rois), axis=0), [batch_size, 1])
if self._config_dict['class_agnostic']:
class_gather_indices = tf.zeros_like(roi_classes, dtype=tf.int32)
else:
class_gather_indices = tf.cast(roi_classes, dtype=tf.int32)
gather_indices = tf.stack( gather_indices = tf.stack(
[batch_indices, mask_indices, tf.cast(roi_classes, dtype=tf.int32)], [batch_indices, mask_indices, class_gather_indices],
axis=2) axis=2)
mask_outputs = tf.gather_nd( mask_outputs = tf.gather_nd(
tf.transpose(logits, [0, 1, 4, 2, 3]), gather_indices) tf.transpose(logits, [0, 1, 4, 2, 3]), gather_indices)
......
...@@ -120,6 +120,16 @@ class MaskHeadTest(parameterized.TestCase, tf.test.TestCase): ...@@ -120,6 +120,16 @@ class MaskHeadTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual( self.assertAllEqual(
mask_head.get_config(), new_mask_head.get_config()) mask_head.get_config(), new_mask_head.get_config())
def test_forward_class_agnostic(self):
  """Verifies the class-agnostic mask head returns one mask per ROI."""
  batch_size, num_rois = 2, 10
  head = instance_heads.MaskHead(num_classes=3, class_agnostic=True)
  # Per-ROI feature crops: (batch, rois, height, width, channels).
  roi_features = np.random.rand(batch_size, num_rois, 14, 14, 16)
  # Class ids are ignored for gathering when class_agnostic=True.
  roi_classes = np.zeros((batch_size, num_rois))
  masks = head([roi_features, roi_classes])
  # Default upsample_factor doubles the 14x14 crop to 28x28.
  self.assertAllEqual([batch_size, num_rois, 28, 28], masks.numpy().shape)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment