Commit 77710731 authored by Hongkun Yu, committed by A. Unique TensorFlower

Moves activations to official/modeling

Adds a swish activation without customized gradients.

PiperOrigin-RevId: 272029817
parent f52b8c93
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Activations package definition."""
from official.modeling.activations.gelu import gelu
from official.modeling.activations.swish import swish
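For orientation (an addition, not part of the commit), a minimal sketch of how the package-level API above is meant to be used; the tensor values are arbitrary:

import tensorflow as tf

from official.modeling import activations

x = tf.constant([-1.0, 0.0, 1.0])
print(activations.gelu(x))   # Smooth GELU approximation, elementwise.
print(activations.swish(x))  # x * sigmoid(x), elementwise.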
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gaussian error linear unit."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='Text')
def gelu(x):
  """Gaussian Error Linear Unit.

  This is a smoother version of the RELU.
  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    x: float Tensor to perform activation.

  Returns:
    `x` with the GELU activation applied.
  """
  cdf = 0.5 * (1.0 + tf.tanh(
      (math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
  return x * cdf
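As a sanity check (an addition, not part of the commit), the tanh-based expression above approximates the exact GELU, x * Phi(x), where Phi is the standard normal CDF; a small sketch comparing the two, with the error magnitude stated only approximately:

import math

import tensorflow as tf

from official.modeling.activations import gelu


def exact_gelu(x):
  # Exact GELU: x * Phi(x), with Phi the standard normal CDF via erf.
  x = tf.convert_to_tensor(x, dtype=tf.float32)
  return 0.5 * x * (1.0 + tf.math.erf(x / math.sqrt(2.0)))


x = tf.constant([[.25, 0., -.25], [-1., -2., 3.]])
# The tanh approximation tracks the exact form to within roughly 1e-3.
print(tf.reduce_max(tf.abs(gelu.gelu(x) - exact_gelu(x))))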
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the Gaussian error linear unit."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

from official.modeling.activations import gelu


@keras_parameterized.run_all_keras_modes
class GeluTest(keras_parameterized.TestCase):

  def test_gelu(self):
    expected_data = [[0.14967535, 0., -0.10032465],
                     [-0.15880796, -0.04540223, 2.9963627]]
    gelu_data = gelu.gelu([[.25, 0, -.25], [-1, -2, 3]])
    self.assertAllClose(expected_data, gelu_data)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Customized Swish activation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='Text')
def swish(features):
  """Computes the Swish activation function.

  The tf.nn.swish operation uses a custom gradient to reduce memory usage.
  However, saving custom gradients in a SavedModel is currently not
  supported, so an exported TF-Hub module would not be usable for
  fine-tuning. This wrapper therefore implements swish directly, relying on
  the default TensorFlow gradient computation instead of a custom gradient.

  Args:
    features: A `Tensor` representing preactivation values.

  Returns:
    The activation value.
  """
  features = tf.convert_to_tensor(features)
  return features * tf.nn.sigmoid(features)
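To illustrate the docstring's point, a sketch under assumptions (an addition, not part of the commit): because this swish avoids tf.custom_gradient and is registered as Keras-serializable, a model using it can round-trip through SavedModel; the save path below is a placeholder:

import tensorflow as tf

from official.modeling.activations import swish

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation=swish.swish, input_shape=(8,)),
])
# Export and reload; '/tmp/swish_model' is a hypothetical path. The
# register_keras_serializable decorator lets load_model resolve the
# activation without passing custom_objects.
model.save('/tmp/swish_model', save_format='tf')
restored = tf.keras.models.load_model('/tmp/swish_model')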
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the customized Swish activation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

from official.modeling.activations import swish


@keras_parameterized.run_all_keras_modes
class CustomizedSwishTest(keras_parameterized.TestCase):

  def test_swish(self):
    customized_swish_data = swish.swish([[.25, 0, -.25], [-1, -2, 3]])
    swish_data = tf.nn.swish([[.25, 0, -.25], [-1, -2, 3]])
    self.assertAllClose(customized_swish_data, swish_data)


if __name__ == '__main__':
  tf.test.main()
@@ -18,10 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import math
 import six
 import tensorflow as tf
 
+from official.modeling import activations
+
 
 def pack_inputs(inputs):
   """Pack a list of `inputs` tensors to a tuple.
@@ -74,55 +75,29 @@ def is_special_none_tensor(tensor):
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
 
 
-def gelu(x):
-  """Gaussian Error Linear Unit.
-
-  This is a smoother version of the RELU.
-  Original paper: https://arxiv.org/abs/1606.08415
-
-  Args:
-    x: float Tensor to perform activation.
-
-  Returns:
-    `x` with the GELU activation applied.
-  """
-  cdf = 0.5 * (1.0 + tf.tanh(
-      (math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
-  return x * cdf
-
-
+# TODO(hongkuny): consider moving custom string-map lookup to keras api.
 def get_activation(identifier):
-  """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
+  """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
+
+  It checks string first and if it is one of customized activation not in TF,
+  the corresponding activation will be returned. For non-customized activation
+  names and callable identifiers, always fallback to tf.keras.activations.get.
 
   Args:
-    identifier: String name of the activation function.
+    identifier: String name of the activation function or callable.
 
   Returns:
-    A Python function corresponding to the activation function. If
-    `identifier` is None, empty, or "linear", this will return None.
-    If `identifier` is not a string, it will return `identifier`.
-
-  Raises:
-    ValueError: The `identifier` does not correspond to a known
-      activation.
+    A Python function corresponding to the activation function.
   """
-  if identifier is None:
-    return None
-  elif isinstance(identifier, six.string_types):
+  if isinstance(identifier, six.string_types):
     name_to_fn = {
-        "linear": None,
-        "relu": tf.nn.relu,
-        "gelu": gelu,
-        "tanh": tf.nn.tanh,
+        "gelu": activations.gelu,
+        "custom_swish": activations.swish,
     }
     identifier = str(identifier).lower()
-    if identifier not in name_to_fn:
-      raise ValueError("Unsupported activation function: %s" % (identifier))
-    return name_to_fn[identifier]
-  elif callable(identifier):
-    return identifier
-  else:
-    raise ValueError("Could not interpret activation "
-                     "function identifier: %s" % (identifier))
+    if identifier in name_to_fn:
+      return tf.keras.activations.get(name_to_fn[identifier])
+  return tf.keras.activations.get(identifier)
 
 
 def get_shape_list(tensor, expected_rank=None, name=None):
...
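A short behavior sketch of the rewritten get_activation (an addition, not part of the commit; the module path official.modeling.tf_utils is an assumption, since the diff above does not show the filename):

import tensorflow as tf

from official.modeling import tf_utils  # Assumed module path for this diff.

tf_utils.get_activation('gelu')          # Custom map hit -> activations.gelu.
tf_utils.get_activation('custom_swish')  # Custom map hit -> activations.swish.
tf_utils.get_activation('relu')          # Falls through to tf.keras.activations.get.
tf_utils.get_activation(tf.nn.relu)      # Callables are returned by keras get as-is.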