Commit 8c011f33 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 480725172
parent 25baa631
...@@ -220,10 +220,6 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -220,10 +220,6 @@ class RetinaNetHead(tf.keras.layers.Layer):
this_level_att_norms = [] this_level_att_norms = []
for i in range(self._config_dict['num_convs']): for i in range(self._config_dict['num_convs']):
if level == self._config_dict['min_level']: if level == self._config_dict['min_level']:
if self._config_dict[
'share_classification_heads'] and att_type == 'classification':
att_conv_name = 'classnet-conv_{}'.format(i)
else:
att_conv_name = '{}-conv_{}'.format(att_name, i) att_conv_name = '{}-conv_{}'.format(att_name, i)
if 'kernel_initializer' in conv_kwargs: if 'kernel_initializer' in conv_kwargs:
conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer( conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer(
...@@ -321,7 +317,8 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -321,7 +317,8 @@ class RetinaNetHead(tf.keras.layers.Layer):
x = conv(x) x = conv(x)
x = norm(x) x = norm(x)
x = self._activation(x) x = self._activation(x)
scores[str(level)] = self._classifier(x) classnet_x = x
scores[str(level)] = self._classifier(classnet_x)
# box net. # box net.
x = this_level_features x = this_level_features
...@@ -335,6 +332,12 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -335,6 +332,12 @@ class RetinaNetHead(tf.keras.layers.Layer):
if self._config_dict['attribute_heads']: if self._config_dict['attribute_heads']:
for att_config in self._config_dict['attribute_heads']: for att_config in self._config_dict['attribute_heads']:
att_name = att_config['name'] att_name = att_config['name']
att_type = att_config['type']
if self._config_dict[
'share_classification_heads'] and att_type == 'classification':
attributes[att_name][str(level)] = self._att_predictors[att_name](
classnet_x)
else:
x = this_level_features x = this_level_features
for conv, norm in zip(self._att_convs[att_name], for conv, norm in zip(self._att_convs[att_name],
self._att_norms[att_name][i]): self._att_norms[att_name][i]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment