Commit e37e8049 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Modify EfficientNet to support the functional subclassed model.

PiperOrigin-RevId: 277179351
parent f0141859
...@@ -14,4 +14,6 @@ ...@@ -14,4 +14,6 @@
# ============================================================================== # ==============================================================================
"""Activations package definition.""" """Activations package definition."""
from official.modeling.activations.gelu import gelu from official.modeling.activations.gelu import gelu
from official.modeling.activations.swish import swish from official.modeling.activations.swish import hard_swish
from official.modeling.activations.swish import identity
from official.modeling.activations.swish import simple_swish
...@@ -22,7 +22,7 @@ import tensorflow as tf ...@@ -22,7 +22,7 @@ import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
def swish(features): def simple_swish(features):
"""Computes the Swish activation function. """Computes the Swish activation function.
The tf.nn.swish operation uses a custom gradient to reduce memory usage. The tf.nn.swish operation uses a custom gradient to reduce memory usage.
...@@ -40,3 +40,36 @@ def swish(features): ...@@ -40,3 +40,36 @@ def swish(features):
""" """
features = tf.convert_to_tensor(features) features = tf.convert_to_tensor(features)
return features * tf.nn.sigmoid(features) return features * tf.nn.sigmoid(features)
@tf.keras.utils.register_keras_serializable(package='Text')
def hard_swish(features):
"""Computes a hard version of the swish function.
This operation can be used to reduce computational cost and improve
quantization for edge devices.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features = tf.convert_to_tensor(features)
return features * tf.nn.relu6(features + tf.constant(3.)) * (1. / 6.)
@tf.keras.utils.register_keras_serializable(package='Text')
def identity(features):
"""Computes the identity function.
Useful for helping in quantization.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features = tf.convert_to_tensor(features)
return tf.identity(features)
...@@ -18,6 +18,7 @@ from __future__ import absolute_import ...@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
...@@ -27,9 +28,20 @@ from official.modeling import activations ...@@ -27,9 +28,20 @@ from official.modeling import activations
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class CustomizedSwishTest(keras_parameterized.TestCase): class CustomizedSwishTest(keras_parameterized.TestCase):
def test_gelu(self): def _hard_swish_np(self, x):
customized_swish_data = activations.swish([[.25, 0, -.25], [-1, -2, 3]]) x = np.float32(x)
swish_data = tf.nn.swish([[.25, 0, -.25], [-1, -2, 3]]) return x * np.clip(x + 3, 0, 6) / 6
def test_simple_swish(self):
features = [[.25, 0, -.25], [-1, -2, 3]]
customized_swish_data = activations.simple_swish(features)
swish_data = tf.nn.swish(features)
self.assertAllClose(customized_swish_data, swish_data)
def test_hard_swish(self):
features = [[.25, 0, -.25], [-1, -2, 3]]
customized_swish_data = activations.hard_swish(features)
swish_data = self._hard_swish_np(features)
self.assertAllClose(customized_swish_data, swish_data) self.assertAllClose(customized_swish_data, swish_data)
......
...@@ -92,7 +92,9 @@ def get_activation(identifier): ...@@ -92,7 +92,9 @@ def get_activation(identifier):
if isinstance(identifier, six.string_types): if isinstance(identifier, six.string_types):
name_to_fn = { name_to_fn = {
"gelu": activations.gelu, "gelu": activations.gelu,
"custom_swish": activations.swish, "simple_swish": activations.simple_swish,
"hard_swish": activations.hard_swish,
"identity": activations.identity,
} }
identifier = str(identifier).lower() identifier = str(identifier).lower()
if identifier in name_to_fn: if identifier in name_to_fn:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment