Commit 1587d2db authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 470217720
parent a55cf4d3
@@ -14,6 +14,7 @@
 """Activations package definition."""
 from official.modeling.activations.gelu import gelu
+from official.modeling.activations.mish import mish
 from official.modeling.activations.relu import relu6
 from official.modeling.activations.sigmoid import hard_sigmoid
 from official.modeling.activations.swish import hard_swish
...
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Self Regularized Non-Monotonic Activation Function."""
import tensorflow as tf
from tensorflow_addons.utils import types
@tf.keras.utils.register_keras_serializable(package='Text')
def mish(x: types.TensorLike) -> tf.Tensor:
"""Mish activation function.
Mish: A Self Regularized Non-Monotonic Activation Function
https://arxiv.org/pdf/1908.08681.pdf
Mish(x) = x * tanh(ln(1+e^x))
Args:
x: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
x = tf.convert_to_tensor(x)
return x * tf.tanh(tf.nn.softplus(x))
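For reference, the value asserted in the new test below can be reproduced directly from the formula in the docstring. A minimal sanity check, not part of this commit, assuming TensorFlow 2.x is installed:

```python
# Minimal sanity check for Mish(x) = x * tanh(softplus(x)).
import math

import tensorflow as tf

x = tf.constant([1.0, 0.0])
y = x * tf.tanh(tf.nn.softplus(x))
print(y.numpy())  # ~[0.86509839, 0.0]

# Same value computed by hand from the docstring formula: tanh(ln(1 + e)).
print(math.tanh(math.log(1.0 + math.e)))  # ~0.86509839
```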
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Mish activation."""
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
@keras_parameterized.run_all_keras_modes
class MishTest(keras_parameterized.TestCase):
def test_mish(self):
x = tf.constant([1.0, 0.0])
self.assertAllClose([0.86509839, 0.0], activations.mish(x))
if __name__ == '__main__':
tf.test.main()
@@ -14,6 +14,7 @@
 """Common TF utilities."""
+import functools
 import six
 import tensorflow as tf
@@ -82,19 +83,22 @@ def is_special_none_tensor(tensor):
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
-def get_activation(identifier, use_keras_layer=False):
-  """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
+def get_activation(identifier, use_keras_layer=False, **kwargs):
+  """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
   It checks string first and if it is one of customized activation not in TF,
   the corresponding activation will be returned. For non-customized activation
   names and callable identifiers, always fallback to tf.keras.activations.get.
   Prefers using keras layers when use_keras_layer=True. Now it only supports
-  'relu', 'linear', 'identity', 'swish'.
+  'relu', 'linear', 'identity', 'swish', 'mish', 'leaky_relu', and 'gelu'.
   Args:
     identifier: String name of the activation function or callable.
     use_keras_layer: If True, use keras layer if identifier is allow-listed.
+    **kwargs: Keyword arguments to use to instantiate an activation function.
+      Available only for 'leaky_relu' and 'gelu' when using keras layers.
+      For example: get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)
   Returns:
     A Python function corresponding to the activation function or a keras
@@ -110,8 +114,11 @@ def get_activation(identifier, use_keras_layer=False):
         "swish": "swish",
         "sigmoid": "sigmoid",
         "relu6": tf.nn.relu6,
+        "leaky_relu": functools.partial(tf.nn.leaky_relu, **kwargs),
         "hard_swish": activations.hard_swish,
         "hard_sigmoid": activations.hard_sigmoid,
+        "mish": activations.mish,
+        "gelu": functools.partial(tf.nn.gelu, **kwargs),
     }
     if identifier in keras_layer_allowlist:
       return tf.keras.layers.Activation(keras_layer_allowlist[identifier])
@@ -122,6 +129,7 @@ def get_activation(identifier, use_keras_layer=False):
       "relu6": activations.relu6,
       "hard_sigmoid": activations.hard_sigmoid,
       "identity": activations.identity,
+      "mish": activations.mish,
   }
   if identifier in name_to_fn:
     return tf.keras.activations.get(name_to_fn[identifier])
...
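For context, a minimal usage sketch of the extended `get_activation` (not part of this commit); it assumes the Model Garden package `official/` is importable, e.g. via `pip install tf-models-official`:

```python
import tensorflow as tf

from official.modeling import tf_utils

# 'mish' now resolves through the name_to_fn table to activations.mish
# and is returned as a plain Python callable.
mish_fn = tf_utils.get_activation('mish')
print(mish_fn(tf.constant([1.0, 0.0])))  # ~[0.8651, 0.0]

# With use_keras_layer=True the identifier is looked up in
# keras_layer_allowlist and wrapped in a tf.keras.layers.Activation.
mish_layer = tf_utils.get_activation('mish', use_keras_layer=True)
print(mish_layer(tf.constant([1.0, 0.0])))  # same values, via a Keras layer
```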
@@ -84,6 +84,24 @@ class TFUtilsTest(tf.test.TestCase, parameterized.TestCase):
     for gradient in per_replica_gradients.values:
       self.assertAllClose(gradient, num_cores * tf.ones(shape))
+
+  @parameterized.parameters(('relu', True), ('relu', False),
+                            ('leaky_relu', False), ('leaky_relu', True),
+                            ('mish', True), ('mish', False), ('gelu', True))
+  def test_get_activations(self, name, use_keras_layer):
+    fn = tf_utils.get_activation(name, use_keras_layer)
+    self.assertIsNotNone(fn)
+
+  @combinations.generate(all_strategy_combinations())
+  def test_get_leaky_relu_layer(self, strategy):
+    @tf.function
+    def forward(x):
+      fn = tf_utils.get_activation(
+          'leaky_relu', use_keras_layer=True, alpha=0.1)
+      return strategy.run(fn, args=(x,)).values[0]
+
+    got = forward(tf.constant([-1]))
+    self.assertAllClose(got, tf.constant([-0.1]))
 
 
 if __name__ == '__main__':
   tf.test.main()
...
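The new `test_get_leaky_relu_layer` runs under a distribution strategy; the same behavior can be checked in plain eager mode. A minimal sketch, not part of this commit, showing that the `alpha` keyword is forwarded to `tf.nn.leaky_relu` via `functools.partial`:

```python
import tensorflow as tf

from official.modeling import tf_utils

# leaky_relu(x) = x for x >= 0, alpha * x otherwise.
leaky = tf_utils.get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)
print(leaky(tf.constant([-1.0, 2.0])))  # ~[-0.1, 2.0]
```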
@@ -88,7 +88,7 @@ class SegmentationHead3D(tf.keras.layers.Layer):
       self._bn_axis = -1
     else:
       self._bn_axis = 1
-    self._activation = tf_utils.get_activation(activation)
+    self._activation = tf_utils.get_activation(activation, use_keras_layer=True)
 
   def build(self, input_shape: Union[tf.TensorShape, Sequence[tf.TensorShape]]):
     """Creates the variables of the segmentation head."""
...