"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "717d15719c713fd3ee9ab0d8eb3d98116758036e"
Commit 8f85bea9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 453477038
parent 03f7782d
...@@ -199,15 +199,11 @@ class SpineNet(tf.keras.Model): ...@@ -199,15 +199,11 @@ class SpineNet(tf.keras.Model):
self._use_sync_bn = use_sync_bn self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
if activation == 'relu':
self._activation_fn = tf.nn.relu
elif activation == 'swish':
self._activation_fn = tf.nn.swish
else:
raise ValueError('Activation {} not implemented.'.format(activation))
self._init_block_fn = 'bottleneck' self._init_block_fn = 'bottleneck'
self._num_init_blocks = 2 self._num_init_blocks = 2
self._set_activation_fn(activation)
if use_sync_bn: if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization self._norm = layers.experimental.SyncBatchNormalization
else: else:
...@@ -232,6 +228,14 @@ class SpineNet(tf.keras.Model): ...@@ -232,6 +228,14 @@ class SpineNet(tf.keras.Model):
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints} self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
super(SpineNet, self).__init__(inputs=inputs, outputs=endpoints) super(SpineNet, self).__init__(inputs=inputs, outputs=endpoints)
def _set_activation_fn(self, activation):
if activation == 'relu':
self._activation_fn = tf.nn.relu
elif activation == 'swish':
self._activation_fn = tf.nn.swish
else:
raise ValueError('Activation {} not implemented.'.format(activation))
def _block_group(self, def _block_group(self,
inputs: tf.Tensor, inputs: tf.Tensor,
filters: int, filters: int,
......
...@@ -122,6 +122,18 @@ class SpineNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -122,6 +122,18 @@ class SpineNetTest(parameterized.TestCase, tf.test.TestCase):
# If the serialization was successful, the new config should match the old. # If the serialization was successful, the new config should match the old.
self.assertAllEqual(network.get_config(), new_network.get_config()) self.assertAllEqual(network.get_config(), new_network.get_config())
@parameterized.parameters(
('relu', tf.nn.relu),
('swish', tf.nn.swish)
)
def test_activation(self, activation, activation_fn):
model = spinenet.SpineNet(activation=activation)
self.assertEqual(model._activation_fn, activation_fn)
def test_invalid_activation_raises_valurerror(self):
with self.assertRaises(ValueError):
spinenet.SpineNet(activation='invalid_activation_name')
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment