Commit 1587d2db authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 470217720
parent a55cf4d3
@@ -14,6 +14,7 @@
 """Activations package definition."""
 from official.modeling.activations.gelu import gelu
+from official.modeling.activations.mish import mish
 from official.modeling.activations.relu import relu6
 from official.modeling.activations.sigmoid import hard_sigmoid
 from official.modeling.activations.swish import hard_swish
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Self Regularized Non-Monotonic Activation Function."""
import tensorflow as tf
from tensorflow_addons.utils import types
@tf.keras.utils.register_keras_serializable(package='Text')
def mish(x: types.TensorLike) -> tf.Tensor:
"""Mish activation function.
Mish: A Self Regularized Non-Monotonic Activation Function
https://arxiv.org/pdf/1908.08681.pdf
Mish(x) = x * tanh(ln(1+e^x))
Args:
x: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
x = tf.convert_to_tensor(x)
return x * tf.tanh(tf.nn.softplus(x))
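
As a sanity check on the formula, here is a standalone sketch using only the standard library, independent of the Model Garden code: at x = 1.0, softplus(1.0) = ln(1 + e) ≈ 1.3133 and tanh(1.3133) ≈ 0.8651, matching the constant asserted in the unit test below.

import math

# Mish(x) = x * tanh(ln(1 + e^x)), evaluated by hand at x = 1.0.
x = 1.0
softplus = math.log(1.0 + math.exp(x))  # ln(1 + e) ~= 1.3133
print(x * math.tanh(softplus))          # ~= 0.86509839
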
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Mish activation."""
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
@keras_parameterized.run_all_keras_modes
class MishTest(keras_parameterized.TestCase):
def test_mish(self):
x = tf.constant([1.0, 0.0])
self.assertAllClose([0.86509839, 0.0], activations.mish(x))
if __name__ == '__main__':
tf.test.main()
@@ -14,6 +14,7 @@
 """Common TF utilities."""
+import functools
 import six
 import tensorflow as tf
@@ -82,19 +83,22 @@ def is_special_none_tensor(tensor):
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32


-def get_activation(identifier, use_keras_layer=False):
-  """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
+def get_activation(identifier, use_keras_layer=False, **kwargs):
+  """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`.

   It checks the string first, and if it is one of the customized activations
   not in TF, the corresponding activation is returned. For non-customized
   activation names and callable identifiers, it always falls back to
   tf.keras.activations.get.

   Prefers using keras layers when use_keras_layer=True. Now it only supports
-  'relu', 'linear', 'identity', 'swish'.
+  'relu', 'linear', 'identity', 'swish', 'mish', 'leaky_relu', and 'gelu'.

   Args:
     identifier: String name of the activation function or callable.
     use_keras_layer: If True, use keras layer if identifier is allow-listed.
+    **kwargs: Keyword arguments to use to instantiate an activation function.
+      Available only for 'leaky_relu' and 'gelu' when using keras layers.
+      For example: get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)

   Returns:
     A Python function corresponding to the activation function or a keras
@@ -110,8 +114,11 @@ def get_activation(identifier, use_keras_layer=False):
       "swish": "swish",
       "sigmoid": "sigmoid",
       "relu6": tf.nn.relu6,
+      "leaky_relu": functools.partial(tf.nn.leaky_relu, **kwargs),
       "hard_swish": activations.hard_swish,
       "hard_sigmoid": activations.hard_sigmoid,
+      "mish": activations.mish,
+      "gelu": functools.partial(tf.nn.gelu, **kwargs),
   }
   if identifier in keras_layer_allowlist:
     return tf.keras.layers.Activation(keras_layer_allowlist[identifier])
@@ -122,6 +129,7 @@ def get_activation(identifier, use_keras_layer=False):
       "relu6": activations.relu6,
       "hard_sigmoid": activations.hard_sigmoid,
       "identity": activations.identity,
+      "mish": activations.mish,
   }
   if identifier in name_to_fn:
     return tf.keras.activations.get(name_to_fn[identifier])
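
To make the two lookup paths above concrete, here is a short usage sketch. It assumes the Model Garden package is importable; the comments describe the expected return types.

from official.modeling import tf_utils

# Allow-listed string + use_keras_layer=True -> a tf.keras.layers.Layer.
relu_layer = tf_utils.get_activation('relu', use_keras_layer=True)

# Plain string -> resolved through name_to_fn / tf.keras.activations.get,
# yielding an ordinary Python callable.
mish_fn = tf_utils.get_activation('mish')

# **kwargs are forwarded only for 'leaky_relu' and 'gelu' in layer mode.
leaky = tf_utils.get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)
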
@@ -84,6 +84,24 @@ class TFUtilsTest(tf.test.TestCase, parameterized.TestCase):
     for gradient in per_replica_gradients.values:
       self.assertAllClose(gradient, num_cores * tf.ones(shape))

+  @parameterized.parameters(('relu', True), ('relu', False),
+                            ('leaky_relu', False), ('leaky_relu', True),
+                            ('mish', True), ('mish', False), ('gelu', True))
+  def test_get_activations(self, name, use_keras_layer):
+    fn = tf_utils.get_activation(name, use_keras_layer)
+    self.assertIsNotNone(fn)
+
+  @combinations.generate(all_strategy_combinations())
+  def test_get_leaky_relu_layer(self, strategy):
+
+    @tf.function
+    def forward(x):
+      fn = tf_utils.get_activation(
+          'leaky_relu', use_keras_layer=True, alpha=0.1)
+      return strategy.run(fn, args=(x,)).values[0]
+
+    got = forward(tf.constant([-1]))
+    self.assertAllClose(got, tf.constant([-0.1]))


 if __name__ == '__main__':
   tf.test.main()
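
One subtlety in test_get_leaky_relu_layer: the input tf.constant([-1]) is int32, yet the expected output is the float [-0.1]. As far as I know, tf.nn.leaky_relu casts integer inputs to float32 before computing max(alpha * x, x), which is why the assertion holds. A quick standalone check, outside any distribution strategy:

import tensorflow as tf

fn = lambda x: tf.nn.leaky_relu(x, alpha=0.1)
print(fn(tf.constant([-1])))    # expected: [-0.1], float32 (int input is cast)
print(fn(tf.constant([-1.0])))  # same result with a float input
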
@@ -88,7 +88,7 @@ class SegmentationHead3D(tf.keras.layers.Layer):
       self._bn_axis = -1
     else:
       self._bn_axis = 1
-    self._activation = tf_utils.get_activation(activation)
+    self._activation = tf_utils.get_activation(activation, use_keras_layer=True)

   def build(self, input_shape: Union[tf.TensorShape, Sequence[tf.TensorShape]]):
     """Creates the variables of the segmentation head."""