Commit 6a761cc8 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 355221779
parent 15430ccc
...@@ -74,6 +74,7 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -74,6 +74,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
out_filters, out_filters,
se_ratio, se_ratio,
divisible_by=1, divisible_by=1,
use_3d_input=False,
kernel_initializer='VarianceScaling', kernel_initializer='VarianceScaling',
kernel_regularizer=None, kernel_regularizer=None,
bias_regularizer=None, bias_regularizer=None,
...@@ -89,6 +90,7 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -89,6 +90,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
excitation layer. excitation layer.
divisible_by: `int` ensures all inner dimensions are divisible by this divisible_by: `int` ensures all inner dimensions are divisible by this
number. number.
use_3d_input: `bool` 2D image or 3D input type.
kernel_initializer: kernel_initializer for convolutional layers. kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D. kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None. Default to None.
...@@ -105,15 +107,22 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -105,15 +107,22 @@ class SqueezeExcitation(tf.keras.layers.Layer):
self._out_filters = out_filters self._out_filters = out_filters
self._se_ratio = se_ratio self._se_ratio = se_ratio
self._divisible_by = divisible_by self._divisible_by = divisible_by
self._use_3d_input = use_3d_input
self._activation = activation self._activation = activation
self._gating_activation = gating_activation self._gating_activation = gating_activation
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer self._bias_regularizer = bias_regularizer
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
self._spatial_axis = [1, 2] if not use_3d_input:
self._spatial_axis = [1, 2]
else:
self._spatial_axis = [1, 2, 3]
else: else:
self._spatial_axis = [2, 3] if not use_3d_input:
self._spatial_axis = [2, 3]
else:
self._spatial_axis = [2, 3, 4]
self._activation_fn = tf_utils.get_activation(activation) self._activation_fn = tf_utils.get_activation(activation)
self._gating_activation_fn = tf_utils.get_activation(gating_activation) self._gating_activation_fn = tf_utils.get_activation(gating_activation)
...@@ -150,6 +159,7 @@ class SqueezeExcitation(tf.keras.layers.Layer): ...@@ -150,6 +159,7 @@ class SqueezeExcitation(tf.keras.layers.Layer):
'out_filters': self._out_filters, 'out_filters': self._out_filters,
'se_ratio': self._se_ratio, 'se_ratio': self._se_ratio,
'divisible_by': self._divisible_by, 'divisible_by': self._divisible_by,
'use_3d_input': self._use_3d_input,
'kernel_initializer': self._kernel_initializer, 'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer, 'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer, 'bias_regularizer': self._bias_regularizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment