Commit 08a9f1f8 authored by Xianzhi Du, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 458114554
parent e723d9da
@@ -124,6 +124,7 @@ class RetinaNetHeadQuantized(tf.keras.layers.Layer):
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
num_params_per_anchor: int = 4,
share_classification_heads: bool = False,
**kwargs):
"""Initializes a RetinaNet quantized head.
@@ -156,8 +157,13 @@ class RetinaNetHeadQuantized(tf.keras.layers.Layer):
box. For example, `num_params_per_anchor` would be 4 for axis-aligned
anchor boxes specified by their y-centers, x-centers, heights, and
widths.
share_classification_heads: A `bool` that indicates whether to share
weights among the main and attribute classification heads. Not used in
the QAT model.
**kwargs: Additional keyword arguments to be passed.
"""
del share_classification_heads
super().__init__(**kwargs)
self._config_dict = {
'min_level': min_level,
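The quantized head accepts the new `share_classification_heads` argument only to keep its constructor signature aligned with the float `RetinaNetHead`; it drops the flag immediately (`del share_classification_heads`) because weight sharing is not applied in the QAT model. A minimal sketch of that pattern, using a simplified stand-in class rather than the real `RetinaNetHeadQuantized`:

```python
import tensorflow as tf


class StandInQuantizedHead(tf.keras.layers.Layer):
  """Illustrative stand-in: accepts an unused kwarg for signature parity."""

  def __init__(self,
               num_classes: int,
               share_classification_heads: bool = False,
               **kwargs):
    # Accepted so callers can build the float and quantized heads from the
    # same config dict, but intentionally unused here.
    del share_classification_heads
    super().__init__(**kwargs)
    self._num_classes = num_classes


head = StandInQuantizedHead(num_classes=3, share_classification_heads=True)
print(head._num_classes)  # 3; the sharing flag was dropped on purpose.
```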
@@ -107,6 +107,7 @@ class RetinaNetHead(hyperparams.Config):
num_filters: int = 256
use_separable_conv: bool = False
attribute_heads: List[AttributeHead] = dataclasses.field(default_factory=list)
share_classification_heads: bool = False
@dataclasses.dataclass
@@ -293,6 +293,7 @@ def build_retinanet(
attribute_heads=[
cfg.as_dict() for cfg in (head_config.attribute_heads or [])
],
share_classification_heads=head_config.share_classification_heads,
use_separable_conv=head_config.use_separable_conv,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
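The new config field above is forwarded unchanged by `build_retinanet` to the modeling head. A minimal sketch of how the flag travels through a head config, using simplified stand-in dataclasses rather than the exact `hyperparams.Config` classes in the repo:

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class AttributeHeadConfig:
  """Stand-in for the per-attribute head config referenced above."""
  name: str = ''
  type: str = 'regression'  # 'regression' or 'classification'
  size: int = 1


@dataclasses.dataclass
class RetinaNetHeadConfig:
  """Stand-in for the RetinaNetHead config with the new field."""
  num_convs: int = 4
  num_filters: int = 256
  use_separable_conv: bool = False
  attribute_heads: List[AttributeHeadConfig] = dataclasses.field(
      default_factory=list)
  share_classification_heads: bool = False


# Enable sharing between the main classification head and a
# classification-type attribute head, mirroring the builder call above.
head_config = RetinaNetHeadConfig(
    attribute_heads=[
        AttributeHeadConfig(name='depth', type='classification', size=1)
    ],
    share_classification_heads=True)
print(head_config.share_classification_heads)  # True
```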
@@ -37,6 +37,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
num_convs: int = 4,
num_filters: int = 256,
attribute_heads: Optional[List[Dict[str, Any]]] = None,
share_classification_heads: bool = False,
use_separable_conv: bool = False,
activation: str = 'relu',
use_sync_bn: bool = False,
@@ -62,6 +63,8 @@ class RetinaNetHead(tf.keras.layers.Layer):
additional attribute head. Each dict consists of 3 key-value pairs:
`name`, `type` ('regression' or 'classification'), and `size` (number
of predicted values for each instance).
share_classification_heads: A `bool` that indicates whether to share
weights among the main and attribute classification heads.
use_separable_conv: A `bool` that indicates whether separable
convolution layers are used.
activation: A `str` that indicates which activation is used, e.g. 'relu',
@@ -88,6 +91,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
'num_convs': num_convs,
'num_filters': num_filters,
'attribute_heads': attribute_heads,
'share_classification_heads': share_classification_heads,
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
@@ -216,7 +220,11 @@ class RetinaNetHead(tf.keras.layers.Layer):
this_level_att_norms = []
for i in range(self._config_dict['num_convs']):
if level == self._config_dict['min_level']:
att_conv_name = '{}-conv_{}'.format(att_name, i)
if self._config_dict[
'share_classification_heads'] and att_type == 'classification':
att_conv_name = 'classnet-conv_{}'.format(i)
else:
att_conv_name = '{}-conv_{}'.format(att_name, i)
if 'kernel_initializer' in conv_kwargs:
conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer(
conv_kwargs['kernel_initializer'])
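When sharing is enabled, attribute heads of type `'classification'` name their convolutions `classnet-conv_{i}`, the same names used by the main classification net, instead of `{att_name}-conv_{i}`. The hunk above only shows the naming; the sketch below illustrates the presumed intent of routing both branches through one conv stack by reusing the same layer objects (an assumption for illustration, not the repo's exact mechanism):

```python
import tensorflow as tf

num_convs = 2
# One shared conv stack, named like the main classification net.
shared_convs = [
    tf.keras.layers.Conv2D(256, 3, padding='same', name=f'classnet-conv_{i}')
    for i in range(num_convs)
]


def run_shared_stack(features: tf.Tensor) -> tf.Tensor:
  """Applies the shared classification convs to one feature map."""
  x = features
  for conv in shared_convs:
    x = tf.nn.relu(conv(x))
  return x


features = tf.random.normal([1, 32, 32, 256])
main_cls_feats = run_shared_stack(features)  # main classification branch
attr_cls_feats = run_shared_stack(features)  # attribute branch, same weights
print(main_cls_feats.shape, attr_cls_feats.shape)
```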
@@ -25,14 +25,15 @@ from official.vision.modeling.heads import dense_prediction_heads
class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(False, False, False),
(False, True, False),
(True, False, True),
(True, True, True),
(False, False, False, None, False),
(False, True, False, None, False),
(True, False, True, 'regression', False),
(True, True, True, 'classification', True),
)
def test_forward(self, use_separable_conv, use_sync_bn, has_att_heads):
def test_forward(self, use_separable_conv, use_sync_bn, has_att_heads,
att_type, share_classification_heads):
if has_att_heads:
attribute_heads = [dict(name='depth', type='regression', size=1)]
attribute_heads = [dict(name='depth', type=att_type, size=1)]
else:
attribute_heads = None
@@ -44,6 +45,7 @@ class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
num_convs=2,
num_filters=256,
attribute_heads=attribute_heads,
share_classification_heads=share_classification_heads,
use_separable_conv=use_separable_conv,
activation='relu',
use_sync_bn=use_sync_bn,