Commit 9239c294 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 369697787
parent 7678a1e9
@@ -16,7 +16,7 @@
 """RetinaNet configuration definition."""
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional
 import dataclasses
 from official.core import config_definitions as cfg
 from official.core import exp_factory
@@ -88,12 +88,19 @@ class Losses(hyperparams.Config):
   l2_weight_decay: float = 0.0
 
 
+@dataclasses.dataclass
+class AttributeHead(hyperparams.Config):
+  name: str = ''
+  type: str = 'regression'
+  size: int = 1
+
+
 @dataclasses.dataclass
 class RetinaNetHead(hyperparams.Config):
   num_convs: int = 4
   num_filters: int = 256
   use_separable_conv: bool = False
-  attribute_heads: Optional[Dict[str, Tuple[str, int]]] = None
+  attribute_heads: Optional[List[AttributeHead]] = None
 
 
 @dataclasses.dataclass
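Under the new schema, each attribute head is declared in the experiment config as an AttributeHead entry instead of a (type, size) tuple keyed by name. A minimal usage sketch, mirroring the updated tests in this commit (the 'depth' attribute is only an example name):

# Sketch: configuring one regression attribute head with the new config classes.
head_config = retinanet_cfg.RetinaNetHead(
    attribute_heads=[
        retinanet_cfg.AttributeHead(name='depth', type='regression', size=1),
    ])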
@@ -223,7 +223,9 @@ def build_retinanet(
       num_anchors_per_location=num_anchors_per_location,
       num_convs=head_config.num_convs,
       num_filters=head_config.num_filters,
-      attribute_heads=head_config.attribute_heads,
+      attribute_heads=[
+          cfg.as_dict() for cfg in (head_config.attribute_heads or [])
+      ],
       use_separable_conv=head_config.use_separable_conv,
       activation=norm_activation_config.activation,
       use_sync_bn=norm_activation_config.use_sync_bn,
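build_retinanet now converts each AttributeHead config to a plain dict via as_dict() before handing it to the Keras head, so the layer only sees framework-agnostic dicts. For the example config above, the converted argument would look roughly like:

# Sketch of the `attribute_heads` value after the as_dict() conversion.
attribute_heads = [{'name': 'depth', 'type': 'regression', 'size': 1}]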
@@ -75,21 +75,36 @@ class MaskRCNNBuilderTest(parameterized.TestCase, tf.test.TestCase):
 class RetinaNetBuilderTest(parameterized.TestCase, tf.test.TestCase):
 
   @parameterized.parameters(
-      ('resnet', (640, 640)),
-      ('resnet', (None, None)),
+      ('resnet', (640, 640), False),
+      ('resnet', (None, None), True),
   )
-  def test_builder(self, backbone_type, input_size):
+  def test_builder(self, backbone_type, input_size, has_att_heads):
     num_classes = 2
     input_specs = tf.keras.layers.InputSpec(
         shape=[None, input_size[0], input_size[1], 3])
+    if has_att_heads:
+      attribute_heads_config = [
+          retinanet_cfg.AttributeHead(name='att1'),
+          retinanet_cfg.AttributeHead(
+              name='att2', type='classification', size=2),
+      ]
+    else:
+      attribute_heads_config = None
     model_config = retinanet_cfg.RetinaNet(
         num_classes=num_classes,
-        backbone=backbones.Backbone(type=backbone_type))
+        backbone=backbones.Backbone(type=backbone_type),
+        head=retinanet_cfg.RetinaNetHead(
+            attribute_heads=attribute_heads_config))
     l2_regularizer = tf.keras.regularizers.l2(5e-5)
     _ = factory.build_retinanet(
         input_specs=input_specs,
         model_config=model_config,
         l2_regularizer=l2_regularizer)
+    if has_att_heads:
+      self.assertEqual(model_config.head.attribute_heads[0].as_dict(),
+                       dict(name='att1', type='regression', size=1))
+      self.assertEqual(model_config.head.attribute_heads[1].as_dict(),
+                       dict(name='att2', type='classification', size=2))
 
 
 class VideoClassificationModelBuilderTest(parameterized.TestCase,
@@ -14,7 +14,7 @@
 """Contains definitions of dense prediction heads."""
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 # Import libraries
@@ -36,7 +36,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
               num_anchors_per_location: int,
               num_convs: int = 4,
               num_filters: int = 256,
-              attribute_heads: Mapping[str, Tuple[str, int]] = None,
+              attribute_heads: List[Dict[str, Any]] = None,
               use_separable_conv: bool = False,
               activation: str = 'relu',
               use_sync_bn: bool = False,
@@ -57,9 +57,10 @@ class RetinaNetHead(tf.keras.layers.Layer):
         conv layers before the prediction.
       num_filters: An `int` number that represents the number of filters of the
         intermediate conv layers.
-      attribute_heads: If not None, a dict that contains
-        (attribute_name, attribute_config) for additional attribute heads.
-        `attribute_config` is a tuple of (attribute_type, attribute_size).
+      attribute_heads: If not None, a list that contains a dict for each
+        additional attribute head. Each dict consists of 3 key-value pairs:
+        `name`, `type` ('regression' or 'classification'), and `size` (number
+        of predicted values for each instance).
       use_separable_conv: A `bool` that indicates whether the separable
         convolution layers is used.
       activation: A `str` that indicates which activation is used, e.g. 'relu',
@@ -189,11 +190,12 @@ class RetinaNetHead(tf.keras.layers.Layer):
       self._att_convs = {}
       self._att_norms = {}
-      for att_name, att_head in self._config_dict['attribute_heads'].items():
+      for att_config in self._config_dict['attribute_heads']:
+        att_name = att_config['name']
+        att_type = att_config['type']
+        att_size = att_config['size']
         att_convs_i = []
         att_norms_i = []
-        att_type = att_head[0]
-        att_size = att_head[1]
 
         # Build conv and norm layers.
         for level in range(self._config_dict['min_level'],
@@ -277,8 +279,8 @@ class RetinaNetHead(tf.keras.layers.Layer):
     boxes = {}
     if self._config_dict['attribute_heads']:
       attributes = {
-          att_name: {}
-          for att_name in self._config_dict['attribute_heads'].keys()
+          att_config['name']: {}
+          for att_config in self._config_dict['attribute_heads']
       }
     else:
       attributes = {}
@@ -306,7 +308,8 @@ class RetinaNetHead(tf.keras.layers.Layer):
       # attribute nets.
       if self._config_dict['attribute_heads']:
-        for att_name in self._config_dict['attribute_heads'].keys():
+        for att_config in self._config_dict['attribute_heads']:
+          att_name = att_config['name']
           x = this_level_features
           for conv, norm in zip(self._att_convs[att_name],
                                 self._att_norms[att_name][i]):
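Callers that construct dense_prediction_heads.RetinaNetHead directly need to switch the attribute_heads argument from a name-keyed dict of tuples to a list of dicts with explicit keys, as the updated tests below show:

# Old format: attribute name mapped to a (type, size) tuple.
attribute_heads = {'depth': ('regression', 1)}
# New format: one dict per attribute head with 'name', 'type' and 'size' keys.
attribute_heads = [dict(name='depth', type='regression', size=1)]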
@@ -33,7 +33,7 @@ class RetinaNetHeadTest(parameterized.TestCase, tf.test.TestCase):
   )
   def test_forward(self, use_separable_conv, use_sync_bn, has_att_heads):
     if has_att_heads:
-      attribute_heads = {'depth': ('regression', 1)}
+      attribute_heads = [dict(name='depth', type='regression', size=1)]
     else:
       attribute_heads = None
@@ -101,7 +101,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
     anchor_boxes = None
     if has_att_heads:
-      attribute_heads = {'depth': ('regression', 1)}
+      attribute_heads = [dict(name='depth', type='regression', size=1)]
     else:
       attribute_heads = None
@@ -181,7 +181,7 @@ class RetinaNetTest(parameterized.TestCase, tf.test.TestCase):
         max_level=max_level)
     if has_att_heads:
-      attribute_heads = {'depth': ('regression', 1)}
+      attribute_heads = [dict(name='depth', type='regression', size=1)]
     else:
       attribute_heads = None
     head = dense_prediction_heads.RetinaNetHead(