"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "301e2e9802d29352cbc6c4824ef44614cb6bd0cb"
Commit 8c011f33 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 480725172
parent 25baa631
...@@ -220,11 +220,7 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -220,11 +220,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
this_level_att_norms = [] this_level_att_norms = []
for i in range(self._config_dict['num_convs']): for i in range(self._config_dict['num_convs']):
if level == self._config_dict['min_level']: if level == self._config_dict['min_level']:
if self._config_dict[ att_conv_name = '{}-conv_{}'.format(att_name, i)
'share_classification_heads'] and att_type == 'classification':
att_conv_name = 'classnet-conv_{}'.format(i)
else:
att_conv_name = '{}-conv_{}'.format(att_name, i)
if 'kernel_initializer' in conv_kwargs: if 'kernel_initializer' in conv_kwargs:
conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer( conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer(
conv_kwargs['kernel_initializer']) conv_kwargs['kernel_initializer'])
...@@ -321,7 +317,8 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -321,7 +317,8 @@ class RetinaNetHead(tf.keras.layers.Layer):
x = conv(x) x = conv(x)
x = norm(x) x = norm(x)
x = self._activation(x) x = self._activation(x)
scores[str(level)] = self._classifier(x) classnet_x = x
scores[str(level)] = self._classifier(classnet_x)
# box net. # box net.
x = this_level_features x = this_level_features
...@@ -335,13 +332,19 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -335,13 +332,19 @@ class RetinaNetHead(tf.keras.layers.Layer):
if self._config_dict['attribute_heads']: if self._config_dict['attribute_heads']:
for att_config in self._config_dict['attribute_heads']: for att_config in self._config_dict['attribute_heads']:
att_name = att_config['name'] att_name = att_config['name']
x = this_level_features att_type = att_config['type']
for conv, norm in zip(self._att_convs[att_name], if self._config_dict[
self._att_norms[att_name][i]): 'share_classification_heads'] and att_type == 'classification':
x = conv(x) attributes[att_name][str(level)] = self._att_predictors[att_name](
x = norm(x) classnet_x)
x = self._activation(x) else:
attributes[att_name][str(level)] = self._att_predictors[att_name](x) x = this_level_features
for conv, norm in zip(self._att_convs[att_name],
self._att_norms[att_name][i]):
x = conv(x)
x = norm(x)
x = self._activation(x)
attributes[att_name][str(level)] = self._att_predictors[att_name](x)
return scores, boxes, attributes return scores, boxes, attributes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment