Commit dbf19582 authored by Hao Wu's avatar Hao Wu Committed by A. Unique TensorFlower
Browse files

Internal changes.

PiperOrigin-RevId: 410294118
parent fdac80fe
...@@ -110,6 +110,8 @@ def get_activation(identifier, use_keras_layer=False): ...@@ -110,6 +110,8 @@ def get_activation(identifier, use_keras_layer=False):
"swish": "swish", "swish": "swish",
"sigmoid": "sigmoid", "sigmoid": "sigmoid",
"relu6": tf.nn.relu6, "relu6": tf.nn.relu6,
"hard_swish": activations.hard_swish,
"hard_sigmoid": activations.hard_sigmoid,
} }
if identifier in keras_layer_allowlist: if identifier in keras_layer_allowlist:
return tf.keras.layers.Activation(keras_layer_allowlist[identifier]) return tf.keras.layers.Activation(keras_layer_allowlist[identifier])
......
...@@ -80,39 +80,6 @@ def get_padding_for_kernel_size(kernel_size): ...@@ -80,39 +80,6 @@ def get_padding_for_kernel_size(kernel_size):
kernel_size)) kernel_size))
def hard_swish(x: tf.Tensor) -> tf.Tensor:
  """Computes the hard-swish (a.k.a. swish6 / h-swish) activation.

  Defined as `x * relu6(x + 3) / 6`, a piecewise-linear approximation of
  swish. See Section 5.2 of Howard et al., "Searching for MobileNetV3"
  (https://arxiv.org/pdf/1905.02244.pdf).

  Args:
    x: the input tensor.

  Returns:
    A tensor of the same shape as `x` with hard-swish applied elementwise.
  """
  # relu6 clamps (x + 3) into [0, 6]; scaling by 1/6 maps the gate to [0, 1].
  gate = tf.nn.relu6(x + 3.)
  return x * gate * (1. / 6.)


# Register under the 'hard_swish' name so tf.keras.layers.Activation and
# model deserialization can resolve it by string identifier.
tf.keras.utils.get_custom_objects().update({'hard_swish': hard_swish})
def simple_swish(x: tf.Tensor) -> tf.Tensor:
  """Computes swish/SiLU (`x * sigmoid(x)`) without custom gradients.

  Useful when exporting to SavedModel, since the custom-gradient variant
  emits warnings on export.

  Args:
    x: the input tensor.

  Returns:
    A tensor of the same shape as `x` with swish applied elementwise.
  """
  gate = tf.math.sigmoid(x)
  return x * gate


# Register under the 'simple_swish' name so it can be looked up by string
# identifier (e.g. via tf.keras.layers.Activation).
tf.keras.utils.get_custom_objects().update({'simple_swish': simple_swish})
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class SqueezeExcitation(tf.keras.layers.Layer): class SqueezeExcitation(tf.keras.layers.Layer):
"""Creates a squeeze and excitation layer.""" """Creates a squeeze and excitation layer."""
......
...@@ -24,11 +24,6 @@ from official.vision.beta.modeling.layers import nn_layers ...@@ -24,11 +24,6 @@ from official.vision.beta.modeling.layers import nn_layers
class NNLayersTest(parameterized.TestCase, tf.test.TestCase): class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
def test_hard_swish(self):
  """Checks hard_swish values at the kink points and interior of the ramp.

  hard_swish(x) = x * relu6(x + 3) / 6, so -3 and 0 map to 0, values >= 3
  pass through unchanged, and -1.5 lands mid-ramp at -0.375.
  """
  layer = tf.keras.layers.Activation('hard_swish')
  inputs = tf.constant([-3, -1.5, 0, 3])
  self.assertAllEqual(layer(inputs), [0., -0.375, 0., 3.])
def test_scale(self): def test_scale(self):
scale = nn_layers.Scale(initializer=tf.keras.initializers.constant(10.)) scale = nn_layers.Scale(initializer=tf.keras.initializers.constant(10.))
output = scale(3.) output = scale(3.)
......
...@@ -22,6 +22,7 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union ...@@ -22,6 +22,7 @@ from typing import Any, Mapping, Optional, Sequence, Tuple, Union
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.layers import nn_layers from official.vision.beta.modeling.layers import nn_layers
# Default kernel weight decay that may be overridden # Default kernel weight decay that may be overridden
...@@ -323,7 +324,8 @@ class ConvBlock(tf.keras.layers.Layer): ...@@ -323,7 +324,8 @@ class ConvBlock(tf.keras.layers.Layer):
self._use_buffered_input = use_buffered_input self._use_buffered_input = use_buffered_input
if activation is not None: if activation is not None:
self._activation_layer = tf.keras.layers.Activation(activation) self._activation_layer = tf_utils.get_activation(
activation, use_keras_layer=True)
else: else:
self._activation_layer = None self._activation_layer = None
......
...@@ -338,7 +338,7 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -338,7 +338,7 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
predicted = tf.concat(predicted, axis=1) predicted = tf.concat(predicted, axis=1)
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected) self.assertAllClose(predicted, expected, atol=1e-4)
self.assertAllClose( self.assertAllClose(
predicted, predicted,
...@@ -349,7 +349,8 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -349,7 +349,8 @@ class MovinetLayersTest(parameterized.TestCase, tf.test.TestCase):
[[[3., 3., 3.]], [[[3., 3., 3.]],
[[3., 3., 3.]]], [[3., 3., 3.]]],
[[[4., 4., 4.]], [[[4., 4., 4.]],
[[4., 4., 4.]]]]]) [[4., 4., 4.]]]]],
atol=1e-4)
def test_stream_movinet_block(self): def test_stream_movinet_block(self):
block = movinet_layers.MovinetBlock( block = movinet_layers.MovinetBlock(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment