Commit 9239c294 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 369697787
parent 7678a1e9
......@@ -16,7 +16,7 @@
"""RetinaNet configuration definition."""
import os
from typing import Dict, List, Optional, Tuple
from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
......@@ -88,12 +88,19 @@ class Losses(hyperparams.Config):
l2_weight_decay: float = 0.0
# Configuration of one additional attribute head for RetinaNet. The model
# builder converts each entry with as_dict() before handing the list to the
# prediction head.
@dataclasses.dataclass
class AttributeHead(hyperparams.Config):
# Attribute name; the prediction head uses it as the key for this attribute's
# layers and outputs.
name: str = ''
# Kind of prediction: 'regression' or 'classification'.
type: str = 'regression'
# Number of predicted values for each instance.
size: int = 1
# Head settings for RetinaNet; build_retinanet forwards these fields to the
# dense prediction head.
@dataclasses.dataclass
class RetinaNetHead(hyperparams.Config):
# Number of conv layers before the prediction.
num_convs: int = 4
# Number of filters of the intermediate conv layers.
num_filters: int = 256
# Whether separable convolution layers are used.
use_separable_conv: bool = False
# NOTE(review): this is a diff rendered without +/- markers, so the field
# appears twice. The dict-typed line looks like the removed version and the
# list-typed line the added one (other hunks here switch from dicts to lists)
# -- confirm against the real file.
attribute_heads: Optional[Dict[str, Tuple[str, int]]] = None
# Optional list of AttributeHead configs; the builder treats None as an empty
# list.
attribute_heads: Optional[List[AttributeHead]] = None
@dataclasses.dataclass
......
......@@ -223,7 +223,9 @@ def build_retinanet(
num_anchors_per_location=num_anchors_per_location,
num_convs=head_config.num_convs,
num_filters=head_config.num_filters,
attribute_heads=head_config.attribute_heads,
attribute_heads=[
cfg.as_dict() for cfg in (head_config.attribute_heads or [])
],
use_separable_conv=head_config.use_separable_conv,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
......
......@@ -75,21 +75,36 @@ class MaskRCNNBuilderTest(parameterized.TestCase, tf.test.TestCase):
# Builder test for RetinaNet, parameterized over the input size and over
# whether attribute heads are configured.
# NOTE(review): this is a diff rendered without +/- markers; where two
# near-identical lines follow each other (parameter tuples, the test_builder
# signature, the backbone= argument), the first looks like the pre-change
# line -- confirm against the real file.
class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
('resnet', (640, 640)),
('resnet', (None, None)),
('resnet', (640, 640), False),
('resnet', (None, None), True),
)
def test_builder(self, backbone_type, input_size):
def test_builder(self, backbone_type, input_size, has_att_heads):
num_classes = 2
input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3])
if has_att_heads:
# 'att1' relies on the AttributeHead defaults; 'att2' overrides type and
# size.
attribute_heads_config = [
retinanet_cfg.AttributeHead(name='att1'),
retinanet_cfg.AttributeHead(
name='att2', type='classification', size=2),
]
else:
attribute_heads_config = None
model_config = retinanet_cfg.RetinaNet(
num_classes=num_classes,
backbone=backbones.Backbone(type=backbone_type))
backbone=backbones.Backbone(type=backbone_type),
head=retinanet_cfg.RetinaNetHead(
attribute_heads=attribute_heads_config))
l2_regularizer = tf.keras.regularizers.l2(5e-5)
# The built model is discarded; the call only has to succeed.
_ = factory.build_retinanet(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
# After the build, each attribute head config must serialize to the expected
# dict, including the defaults filled in for 'att1'.
if has_att_heads:
self.assertEqual(model_config.head.attribute_heads[0].as_dict(),
dict(name='att1', type='regression', size=1))
self.assertEqual(model_config.head.attribute_heads[1].as_dict(),
dict(name='att2', type='classification', size=2))
class VideoClassificationModelBuilderTest(parameterized.TestCase,
......
......@@ -14,7 +14,7 @@
"""Contains definitions of dense prediction heads."""
from typing import List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Union
# Import libraries
......@@ -36,7 +36,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
num_anchors_per_location: int,
num_convs: int = 4,
num_filters: int = 256,
attribute_heads: Mapping[str, Tuple[str, int]] = None,
attribute_heads: List[Dict[str, Any]] = None,
use_separable_conv: bool = False,
activation: str = 'relu',
use_sync_bn: bool = False,
......@@ -57,9 +57,10 @@ class RetinaNetHead(tf.keras.layers.Layer):
conv layers before the prediction.
num_filters: An `int` number that represents the number of filters of the
intermediate conv layers.
attribute_heads: If not None, a dict that contains
(attribute_name, attribute_config) for additional attribute heads.
`attribute_config` is a tuple of (attribute_type, attribute_size).
attribute_heads: If not None, a list that contains a dict for each
additional attribute head. Each dict consists of 3 key-value pairs:
`name`, `type` ('regression' or 'classification'), and `size` (number
of predicted values for each instance).
use_separable_conv: A `bool` that indicates whether the separable
convolution layers is used.
activation: A `str` that indicates which activation is used, e.g. 'relu',
......@@ -189,11 +190,12 @@ class RetinaNetHead(tf.keras.layers.Layer):
self._att_convs = {}
self._att_norms = {}
for att_name, att_head in self._config_dict['attribute_heads'].items():
for att_config in self._config_dict['attribute_heads']:
att_name = att_config['name']
att_type = att_config['type']
att_size = att_config['size']
att_convs_i = []
att_norms_i = []
att_type = att_head[0]
att_size = att_head[1]
# Build conv and norm layers.
for level in range(self._config_dict['min_level'],
......@@ -277,8 +279,8 @@ class RetinaNetHead(tf.keras.layers.Layer):
boxes = {}
if self._config_dict['attribute_heads']:
attributes = {
att_name: {}
for att_name in self._config_dict['attribute_heads'].keys()
att_config['name']: {}
for att_config in self._config_dict['attribute_heads']
}
else:
attributes = {}
......@@ -306,7 +308,8 @@ class RetinaNetHead(tf.keras.layers.Layer):
# attribute nets.
if self._config_dict['attribute_heads']:
for att_name in self._config_dict['attribute_heads'].keys():
for att_config in self._config_dict['attribute_heads']:
att_name = att_config['name']
x = this_level_features
for conv, norm in zip(self._att_convs[att_name],
self._att_norms[att_name][i]):
......
......@@ -33,7 +33,7 @@ class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
)
def test_forward(self, use_separable_conv, use_sync_bn, has_att_heads):
if has_att_heads:
attribute_heads = {'depth': ('regression', 1)}
attribute_heads = [dict(name='depth', type='regression', size=1)]
else:
attribute_heads = None
......
......@@ -101,7 +101,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
anchor_boxes = None
if has_att_heads:
attribute_heads = {'depth': ('regression', 1)}
attribute_heads = [dict(name='depth', type='regression', size=1)]
else:
attribute_heads = None
......@@ -181,7 +181,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
max_level=max_level)
if has_att_heads:
attribute_heads = {'depth': ('regression', 1)}
attribute_heads = [dict(name='depth', type='regression', size=1)]
else:
attribute_heads = None
head = dense_prediction_heads.RetinaNetHead(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment