Commit 5323d280 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 410344878
parent dbf19582
@@ -22,6 +22,7 @@ from absl import logging
 import tensorflow as tf
 from official.modeling import hyperparams
+from official.modeling import tf_utils
 from official.vision.beta.modeling.decoders import factory
 from official.vision.beta.ops import spatial_transform_ops
@@ -165,12 +166,7 @@ class NASFPN(tf.keras.Model):
         'momentum': self._config_dict['norm_momentum'],
         'epsilon': self._config_dict['norm_epsilon'],
     }
-    if activation == 'relu':
-      self._activation = tf.nn.relu
-    elif activation == 'swish':
-      self._activation = tf.nn.swish
-    else:
-      raise ValueError('Activation {} not implemented.'.format(activation))
+    self._activation = tf_utils.get_activation(activation)

     # Gets input feature pyramid from backbone.
     inputs = self._build_input_pyramid(input_specs, min_level)
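Note on the hunk above: the hardcoded relu/swish branch is replaced by the shared helper. A minimal sketch of the resulting behavior, assuming official.modeling.tf_utils.get_activation resolves activation names to TensorFlow callables (e.g. 'relu' -> tf.nn.relu, 'swish' -> tf.nn.swish) and raises for unknown names:

import tensorflow as tf
from official.modeling import tf_utils

# Resolve the activation by name instead of via an if/elif chain.
activation_fn = tf_utils.get_activation('swish')
x = tf.constant([-1.0, 0.0, 1.0])
print(activation_fn(x).numpy())  # swish(x) = x * sigmoid(x) ~ [-0.27, 0.0, 0.73]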
@@ -238,7 +234,11 @@ class NASFPN(tf.keras.Model):
     # dtype mismatch when one input (by default float32 dtype) does not meet all
     # the above conditions and is output unchanged, while other inputs are
     # processed to have different dtype, e.g., using bfloat16 on TPU.
-    return tf.cast(x, dtype=tf.keras.layers.Layer().dtype_policy.compute_dtype)
+    if tf.keras.layers.Layer().dtype_policy.compute_dtype is not None:
+      return tf.cast(
+          x, dtype=tf.keras.layers.Layer().dtype_policy.compute_dtype)
+    else:
+      return x

   def _global_attention(self, feat0, feat1):
     m = tf.math.reduce_max(feat0, axis=[1, 2], keepdims=True)
...
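Note on the dtype guard above: tf.keras.layers.Layer().dtype_policy.compute_dtype can be None (e.g. under the legacy '_infer' policy), in which case tf.cast would receive None as its target dtype; the new branch returns the input unchanged instead. A minimal sketch of the cast path, assuming a freshly constructed Layer inherits the global Keras mixed-precision policy:

import tensorflow as tf

# Under a mixed-precision policy, a fresh Layer reports that policy's
# compute dtype, so the cast aligns a float32 passthrough input with the
# bfloat16 outputs of the processed branches.
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

x = tf.ones([2, 2], dtype=tf.float32)  # passthrough input, still float32
compute_dtype = tf.keras.layers.Layer().dtype_policy.compute_dtype
if compute_dtype is not None:
  x = tf.cast(x, dtype=compute_dtype)
print(x.dtype)  # bfloat16 under the mixed_bfloat16 policy

tf.keras.mixed_precision.set_global_policy('float32')  # restore the default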