Commit e37e8049 authored by A. Unique TensorFlower
Browse files

Modify EfficientNet to support the functional subclassed model.

PiperOrigin-RevId: 277179351
parent f0141859
......@@ -14,4 +14,6 @@
# ==============================================================================
"""Activations package definition."""
from official.modeling.activations.gelu import gelu
from official.modeling.activations.swish import swish
from official.modeling.activations.swish import hard_swish
from official.modeling.activations.swish import identity
from official.modeling.activations.swish import simple_swish
......@@ -22,7 +22,7 @@ import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Text')
def swish(features):
def simple_swish(features):
"""Computes the Swish activation function.
The tf.nn.swish operation uses a custom gradient to reduce memory usage.
......@@ -40,3 +40,36 @@ def swish(features):
"""
features = tf.convert_to_tensor(features)
return features * tf.nn.sigmoid(features)
@tf.keras.utils.register_keras_serializable(package='Text')
def hard_swish(features):
  """Computes a hard version of the swish function.

  Evaluates `features * relu6(features + 3) / 6`. This operation can be used
  to reduce computational cost and improve quantization for edge devices.

  Args:
    features: A `Tensor` representing preactivation values.

  Returns:
    The activation value.
  """
  features = tf.convert_to_tensor(features)
  # Build the offset in the input's own dtype: a bare `tf.constant(3.)` is
  # always float32 and cannot be added to float16/bfloat16/float64 tensors.
  offset = tf.cast(3., features.dtype)
  return features * tf.nn.relu6(features + offset) * (1. / 6.)
@tf.keras.utils.register_keras_serializable(package='Text')
def identity(features):
  """Passes the input through unchanged.

  The input is first converted to a tensor and then routed through
  `tf.identity`. Useful for helping in quantization.

  Args:
    features: A `Tensor` representing preactivation values.

  Returns:
    The activation value.
  """
  return tf.identity(tf.convert_to_tensor(features))
......@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
......@@ -27,9 +28,20 @@ from official.modeling import activations
@keras_parameterized.run_all_keras_modes
class CustomizedSwishTest(keras_parameterized.TestCase):
def test_gelu(self):
customized_swish_data = activations.swish([[.25, 0, -.25], [-1, -2, 3]])
swish_data = tf.nn.swish([[.25, 0, -.25], [-1, -2, 3]])
def _hard_swish_np(self, x):
  """NumPy reference for hard swish: `x * clip(x + 3, 0, 6) / 6` in float32."""
  values = np.float32(x)
  gate = np.clip(values + 3, 0, 6)
  return values * gate / 6
def test_simple_swish(self):
  """Checks that `activations.simple_swish` agrees with `tf.nn.swish`."""
  inputs = [[.25, 0, -.25], [-1, -2, 3]]
  actual = activations.simple_swish(inputs)
  expected = tf.nn.swish(inputs)
  self.assertAllClose(actual, expected)
def test_hard_swish(self):
  """Checks `activations.hard_swish` against the NumPy reference helper."""
  inputs = [[.25, 0, -.25], [-1, -2, 3]]
  actual = activations.hard_swish(inputs)
  expected = self._hard_swish_np(inputs)
  self.assertAllClose(actual, expected)
......
......@@ -92,7 +92,9 @@ def get_activation(identifier):
if isinstance(identifier, six.string_types):
name_to_fn = {
"gelu": activations.gelu,
"custom_swish": activations.swish,
"simple_swish": activations.simple_swish,
"hard_swish": activations.hard_swish,
"identity": activations.identity,
}
identifier = str(identifier).lower()
if identifier in name_to_fn:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment