ModelZoo / ResNet50_tensorflow

Commit 1587d2db, authored Aug 26, 2022 by A. Unique TensorFlower
Internal change
PiperOrigin-RevId: 470217720
parent a55cf4d3
Showing 6 changed files with 101 additions and 4 deletions (+101 −4)
official/modeling/activations/__init__.py  (+1 −0)
official/modeling/activations/mish.py  (+38 −0)
official/modeling/activations/mish_test.py  (+32 −0)
official/modeling/tf_utils.py  (+11 −3)
official/modeling/tf_utils_test.py  (+18 −0)
official/projects/volumetric_models/modeling/heads/segmentation_heads_3d.py  (+1 −1)
official/modeling/activations/__init__.py
@@ -14,6 +14,7 @@
 """Activations package definition."""
 from official.modeling.activations.gelu import gelu
+from official.modeling.activations.mish import mish
 from official.modeling.activations.relu import relu6
 from official.modeling.activations.sigmoid import hard_sigmoid
 from official.modeling.activations.swish import hard_swish
official/modeling/activations/mish.py
new file (mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Self Regularized Non-Monotonic Activation Function."""
import tensorflow as tf

from tensorflow_addons.utils import types


@tf.keras.utils.register_keras_serializable(package='Text')
def mish(x: types.TensorLike) -> tf.Tensor:
  """Mish activation function.

  Mish: A Self Regularized Non-Monotonic Activation Function
  https://arxiv.org/pdf/1908.08681.pdf

  Mish(x) = x * tanh(ln(1+e^x))

  Args:
    x: A `Tensor` representing preactivation values.

  Returns:
    The activation value.
  """
  x = tf.convert_to_tensor(x)
  return x * tf.tanh(tf.nn.softplus(x))
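
A quick numeric sanity check of the formula above (an illustrative standalone snippet, not part of this commit): plain Python math reproduces the reference value asserted in mish_test.py below.

import math

# Mish(x) = x * tanh(ln(1 + e^x)); evaluate at x = 1.0.
x = 1.0
print(x * math.tanh(math.log(1.0 + math.exp(x))))  # ~0.86510, per mish_test.py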
official/modeling/activations/mish_test.py
new file (mode 100644)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Mish activation."""
import tensorflow as tf

from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations


@keras_parameterized.run_all_keras_modes
class MishTest(keras_parameterized.TestCase):

  def test_mish(self):
    x = tf.constant([1.0, 0.0])
    self.assertAllClose([0.86509839, 0.0], activations.mish(x))


if __name__ == '__main__':
  tf.test.main()
official/modeling/tf_utils.py
@@ -14,6 +14,7 @@
 """Common TF utilities."""
+import functools
 import six
 import tensorflow as tf

@@ -82,19 +83,22 @@ def is_special_none_tensor(tensor):
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32


-def get_activation(identifier, use_keras_layer=False):
-  """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
+def get_activation(identifier, use_keras_layer=False, **kwargs):
+  """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`.

   It checks string first and if it is one of customized activation not in TF,
   the corresponding activation will be returned. For non-customized activation
   names and callable identifiers, always fallback to tf.keras.activations.get.

   Prefers using keras layers when use_keras_layer=True. Now it only supports
-  'relu', 'linear', 'identity', 'swish'.
+  'relu', 'linear', 'identity', 'swish', 'mish', 'leaky_relu', and 'gelu'.

   Args:
     identifier: String name of the activation function or callable.
     use_keras_layer: If True, use keras layer if identifier is allow-listed.
+    **kwargs: Keyword arguments to use to instantiate an activation function.
+      Available only for 'leaky_relu' and 'gelu' when using keras layers.
+      For example: get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)

   Returns:
     A Python function corresponding to the activation function or a keras
@@ -110,8 +114,11 @@ def get_activation(identifier, use_keras_layer=False):
       "swish": "swish",
       "sigmoid": "sigmoid",
       "relu6": tf.nn.relu6,
+      "leaky_relu": functools.partial(tf.nn.leaky_relu, **kwargs),
       "hard_swish": activations.hard_swish,
       "hard_sigmoid": activations.hard_sigmoid,
+      "mish": activations.mish,
+      "gelu": functools.partial(tf.nn.gelu, **kwargs),
   }
   if identifier in keras_layer_allowlist:
     return tf.keras.layers.Activation(keras_layer_allowlist[identifier])
@@ -122,6 +129,7 @@ def get_activation(identifier, use_keras_layer=False):
       "relu6": activations.relu6,
       "hard_sigmoid": activations.hard_sigmoid,
       "identity": activations.identity,
+      "mish": activations.mish,
   }
   if identifier in name_to_fn:
     return tf.keras.activations.get(name_to_fn[identifier])
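
To illustrate the extended allow-list, a minimal usage sketch (illustrative only, not part of this commit; it assumes the official.modeling package is importable):

import tensorflow as tf
from official.modeling import tf_utils

# String identifiers resolve to Python functions by default.
mish_fn = tf_utils.get_activation('mish')
print(mish_fn(tf.constant([1.0, 0.0])))  # ~[0.8651, 0.0]

# With use_keras_layer=True, allow-listed names return a keras Activation layer.
mish_layer = tf_utils.get_activation('mish', use_keras_layer=True)

# The new **kwargs are forwarded to the 'leaky_relu' and 'gelu' keras layers.
leaky = tf_utils.get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)
print(leaky(tf.constant([-1.0])))  # ~[-0.1]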
official/modeling/tf_utils_test.py
@@ -84,6 +84,24 @@ class TFUtilsTest(tf.test.TestCase, parameterized.TestCase):
     for gradient in per_replica_gradients.values:
       self.assertAllClose(gradient, num_cores * tf.ones(shape))

+  @parameterized.parameters(('relu', True), ('relu', False),
+                            ('leaky_relu', False), ('leaky_relu', True),
+                            ('mish', True), ('mish', False), ('gelu', True))
+  def test_get_activations(self, name, use_keras_layer):
+    fn = tf_utils.get_activation(name, use_keras_layer)
+    self.assertIsNotNone(fn)
+
+  @combinations.generate(all_strategy_combinations())
+  def test_get_leaky_relu_layer(self, strategy):
+
+    @tf.function
+    def forward(x):
+      fn = tf_utils.get_activation(
+          'leaky_relu', use_keras_layer=True, alpha=0.1)
+      return strategy.run(fn, args=(x,)).values[0]
+
+    got = forward(tf.constant([-1]))
+    self.assertAllClose(got, tf.constant([-0.1]))
+

 if __name__ == '__main__':
   tf.test.main()
official/projects/volumetric_models/modeling/heads/segmentation_heads_3d.py
@@ -88,7 +88,7 @@ class SegmentationHead3D(tf.keras.layers.Layer):
       self._bn_axis = -1
     else:
       self._bn_axis = 1
-    self._activation = tf_utils.get_activation(activation)
+    self._activation = tf_utils.get_activation(activation, use_keras_layer=True)

   def build(self, input_shape: Union[tf.TensorShape, Sequence[tf.TensorShape]]):
     """Creates the variables of the segmentation head."""
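
The practical effect of the one-line change above, per the allow-list branch added in tf_utils.py (an illustrative sketch, not part of this commit):

import tensorflow as tf
from official.modeling import tf_utils

# Without use_keras_layer, get_activation returns a bare Python function.
fn = tf_utils.get_activation('relu')

# With use_keras_layer=True, it returns a tf.keras.layers.Activation wrapper,
# so the segmentation head now stores its activation as a keras layer.
layer = tf_utils.get_activation('relu', use_keras_layer=True)
assert isinstance(layer, tf.keras.layers.Activation)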