Commit 09c999ba authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 402725910
parent c50daa27
...@@ -44,6 +44,7 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -44,6 +44,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
norm_epsilon: float = 0.001, norm_epsilon: float = 0.001,
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None, bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
num_params_per_anchor: int = 4,
**kwargs): **kwargs):
"""Initializes a RetinaNet head. """Initializes a RetinaNet head.
...@@ -72,6 +73,10 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -72,6 +73,10 @@ class RetinaNetHead(tf.keras.layers.Layer):
kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for kernel_regularizer: A `tf.keras.regularizers.Regularizer` object for
Conv2D. Default is None. Conv2D. Default is None.
bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D. bias_regularizer: A `tf.keras.regularizers.Regularizer` object for Conv2D.
num_params_per_anchor: Number of parameters required to specify an anchor
box. For example, `num_params_per_anchor` would be 4 for axis-aligned
anchor boxes specified by their y-centers, x-centers, heights, and
widths.
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
super(RetinaNetHead, self).__init__(**kwargs) super(RetinaNetHead, self).__init__(**kwargs)
...@@ -90,6 +95,7 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -90,6 +95,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
'norm_epsilon': norm_epsilon, 'norm_epsilon': norm_epsilon,
'kernel_regularizer': kernel_regularizer, 'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer, 'bias_regularizer': bias_regularizer,
'num_params_per_anchor': num_params_per_anchor,
} }
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
...@@ -170,7 +176,8 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -170,7 +176,8 @@ class RetinaNetHead(tf.keras.layers.Layer):
self._box_norms.append(this_level_box_norms) self._box_norms.append(this_level_box_norms)
box_regressor_kwargs = { box_regressor_kwargs = {
'filters': 4 * self._config_dict['num_anchors_per_location'], 'filters': (self._config_dict['num_params_per_anchor'] *
self._config_dict['num_anchors_per_location']),
'kernel_size': 3, 'kernel_size': 3,
'padding': 'same', 'padding': 'same',
'bias_initializer': tf.zeros_initializer(), 'bias_initializer': tf.zeros_initializer(),
...@@ -265,7 +272,8 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -265,7 +272,8 @@ class RetinaNetHead(tf.keras.layers.Layer):
- key: A `str` of the level of the multilevel predictions. - key: A `str` of the level of the multilevel predictions.
- values: A `tf.Tensor` of the box scores predicted from a particular - values: A `tf.Tensor` of the box scores predicted from a particular
feature level, whose shape is feature level, whose shape is
[batch, height_l, width_l, 4 * num_anchors_per_location]. [batch, height_l, width_l,
num_params_per_anchor * num_anchors_per_location].
attributes: a dict of (attribute_name, attribute_prediction). Each attributes: a dict of (attribute_name, attribute_prediction). Each
`attribute_prediction` is a dict of: `attribute_prediction` is a dict of:
- key: `str`, the level of the multilevel predictions. - key: `str`, the level of the multilevel predictions.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment