Unverified commit ffb5bb61 authored by Liangzhe, committed by GitHub

Make quantizable_separable_conv2d initializer explicitly configurable (otherwise the initializer can only be configured globally via slim.arg_scope). (#8661)

PiperOrigin-RevId: 315732759
parent 9dadc325
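In effect, callers can now pass the depthwise and pointwise initializers directly instead of routing them through a global slim.arg_scope. A minimal before/after sketch (the initializer, tensor shape, and scope names are illustrative, not part of this commit):

import tensorflow.compat.v1 as tf
import tf_slim as slim
from tensorflow.contrib import layers as contrib_layers
from lstm_object_detection.lstm import utils as lstm_utils

my_init = tf.truncated_normal_initializer(stddev=0.09)
inputs = tf.placeholder(tf.float32, [1, 19, 19, 32])

# Before: the only hook was an arg_scope around the underlying layers.
with slim.arg_scope([contrib_layers.separable_conv2d],
                    weights_initializer=my_init):
  net = lstm_utils.quantizable_separable_conv2d(
      inputs, 64, [3, 3], is_quantized=False, scope='sep_conv_a')

# After: initializers are explicit per-call arguments.
net = lstm_utils.quantizable_separable_conv2d(
    inputs, 64, [3, 3],
    is_quantized=False,
    weights_initializer=my_init,
    pointwise_initializer=my_init,
    scope='sep_conv_b')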
@@ -13,9 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 """BottleneckConvLSTMCell implementation."""
+import functools
 import tensorflow.compat.v1 as tf
 import tf_slim as slim
 from tensorflow.contrib import rnn as contrib_rnn
 from tensorflow.contrib.framework.python.ops import variables as contrib_variables
 import lstm_object_detection.lstm.utils as lstm_utils
@@ -285,7 +287,8 @@ class GroupedConvLSTMCell(contrib_rnn.RNNCell):
                output_bottleneck=False,
                pre_bottleneck=False,
                is_quantized=False,
-               visualize_gates=False):
+               visualize_gates=False,
+               conv_op_overrides=None):
     """Initialize the basic LSTM cell.

     Args:
@@ -311,6 +314,10 @@ class GroupedConvLSTMCell(contrib_rnn.RNNCell):
         quantization friendly concat and separable_conv2d ops.
       visualize_gates: if True, add histogram summaries of all gates and outputs
         to tensorboard
+      conv_op_overrides: A list of convolutional operations that override the
+        'bottleneck' and 'convolution' layers before the LSTM gates. If None,
+        the original separable_conv implementation is used. The list must
+        have length two.

     Raises:
       ValueError: when both clip_state and scale_state are enabled.
@@ -336,6 +343,10 @@ class GroupedConvLSTMCell(contrib_rnn.RNNCell):
     self._is_quantized = is_quantized
     for dim in self._output_size:
       self._param_count *= dim
+    self._conv_op_overrides = conv_op_overrides
+    if self._conv_op_overrides and len(self._conv_op_overrides) != 2:
+      raise ValueError('Bottleneck and convolutional layers should be '
+                       'overridden together.')

   @property
   def state_size(self):
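Given the check above, conv_op_overrides must be a two-element list: entry 0 replaces the bottleneck layer and entry 1 replaces the gate convolution. A hedged sketch of wiring custom initializers through the overrides, assuming the cell lives in lstm_object_detection.lstm.lstm_cells (the constructor arguments other than conv_op_overrides are illustrative):

import functools
import tensorflow.compat.v1 as tf
from lstm_object_detection.lstm import lstm_cells
from lstm_object_detection.lstm import utils as lstm_utils

init = tf.truncated_normal_initializer(stddev=0.09)

# Each override pins kernel_size and activation_fn, mirroring the
# functools.partial defaults the cell builds in the hunks below, and
# adds explicit initializers on top.
bottleneck_op = functools.partial(
    lstm_utils.quantizable_separable_conv2d,
    kernel_size=[3, 3],
    activation_fn=tf.nn.relu6,
    weights_initializer=init,
    pointwise_initializer=init)
conv_op = functools.partial(
    lstm_utils.quantizable_separable_conv2d,
    kernel_size=[3, 3],
    activation_fn=None,
    weights_initializer=init,
    pointwise_initializer=init)

cell = lstm_cells.GroupedConvLSTMCell(
    filter_size=[3, 3],      # illustrative values; see the class
    output_size=[19, 19],    # definition for the full signature
    num_units=256,
    is_quantized=True,
    conv_op_overrides=[bottleneck_op, conv_op])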
@@ -405,23 +416,26 @@ class GroupedConvLSTMCell(contrib_rnn.RNNCell):
         if self._pre_bottleneck:
           bottleneck = inputs_list[k]
         else:
+          if self._conv_op_overrides:
+            bottleneck_fn = self._conv_op_overrides[0]
+          else:
+            bottleneck_fn = functools.partial(
+                lstm_utils.quantizable_separable_conv2d,
+                kernel_size=self._filter_size,
+                activation_fn=self._activation)
           if self._use_batch_norm:
-            b_x = lstm_utils.quantizable_separable_conv2d(
-                inputs,
-                self._num_units // self._groups,
-                self._filter_size,
+            b_x = bottleneck_fn(
+                inputs=inputs,
+                num_outputs=self._num_units // self._groups,
                 is_quantized=self._is_quantized,
                 depth_multiplier=1,
-                activation_fn=None,
                 normalizer_fn=None,
                 scope='bottleneck_%d_x' % k)
-            b_h = lstm_utils.quantizable_separable_conv2d(
-                h_list[k],
-                self._num_units // self._groups,
-                self._filter_size,
+            b_h = bottleneck_fn(
+                inputs=h_list[k],
+                num_outputs=self._num_units // self._groups,
                 is_quantized=self._is_quantized,
                 depth_multiplier=1,
-                activation_fn=None,
                 normalizer_fn=None,
                 scope='bottleneck_%d_h' % k)
             b_x = slim.batch_norm(
@@ -445,24 +459,26 @@ class GroupedConvLSTMCell(contrib_rnn.RNNCell):
                 is_training=False,
                 is_quantized=self._is_quantized,
                 scope='bottleneck_%d/quantized_concat' % k)
-            bottleneck = lstm_utils.quantizable_separable_conv2d(
-                bottleneck_concat,
-                self._num_units // self._groups,
-                self._filter_size,
+            bottleneck = bottleneck_fn(
+                inputs=bottleneck_concat,
+                num_outputs=self._num_units // self._groups,
                 is_quantized=self._is_quantized,
                 depth_multiplier=1,
-                activation_fn=self._activation,
                 normalizer_fn=None,
                 scope='bottleneck_%d' % k)
-        concat = lstm_utils.quantizable_separable_conv2d(
-            bottleneck,
-            4 * self._num_units // self._groups,
-            self._filter_size,
+        if self._conv_op_overrides:
+          conv_fn = self._conv_op_overrides[1]
+        else:
+          conv_fn = functools.partial(
+              lstm_utils.quantizable_separable_conv2d,
+              kernel_size=self._filter_size,
+              activation_fn=None)
+        concat = conv_fn(
+            inputs=bottleneck,
+            num_outputs=4 * self._num_units // self._groups,
             is_quantized=self._is_quantized,
             depth_multiplier=1,
-            activation_fn=None,
             normalizer_fn=None,
             scope='concat_conv_%d' % k)
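Since the default path now goes through the same bottleneck_fn/conv_fn indirection, any override must accept the keyword arguments the cell passes at the call sites above. A conforming override might look like this (a sketch; slim.separable_conv2d stands in for whatever custom op a caller would actually supply):

import tf_slim as slim

def my_conv_override(inputs, num_outputs, is_quantized=False,
                     depth_multiplier=1, normalizer_fn=None, scope=None):
  # The cell calls fn(inputs=..., num_outputs=..., is_quantized=...,
  # depth_multiplier=1, normalizer_fn=None, scope=...), so all of these
  # keywords must be accepted.
  del is_quantized  # this illustrative override ignores quantization
  return slim.separable_conv2d(
      inputs,
      num_outputs,
      [3, 3],  # kernel size pinned by the override itself
      depth_multiplier=depth_multiplier,
      activation_fn=None,
      normalizer_fn=normalizer_fn,
      scope=scope)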
@@ -490,14 +506,6 @@ class GroupedConvLSTMCell(contrib_rnn.RNNCell):
             is_quantized=self._is_quantized,
             scope='forget_gate_%d/add_quant' % k)
         f_act = tf.sigmoid(f_add)
-        # The quantization range is fixed for the sigmoid to ensure that zero
-        # is exactly representable.
-        f_act = lstm_utils.fixed_quantize_op(
-            f_act,
-            fixed_min=0.0,
-            fixed_max=1.0,
-            is_quantized=self._is_quantized,
-            scope='forget_gate_%d/act_quant' % k)

         a = c_list[k] * f_act
         a = lstm_utils.quantize_op(
@@ -507,14 +515,6 @@ class GroupedConvLSTMCell(contrib_rnn.RNNCell):
             scope='forget_gate_%d/mul_quant' % k)

         i_act = tf.sigmoid(i)
-        # The quantization range is fixed for the sigmoid to ensure that zero
-        # is exactly representable.
-        i_act = lstm_utils.fixed_quantize_op(
-            i_act,
-            fixed_min=0.0,
-            fixed_max=1.0,
-            is_quantized=self._is_quantized,
-            scope='input_gate_%d/act_quant' % k)

         j_act = self._activation(j)
         # The quantization range is fixed for the relu6 to ensure that zero
@@ -567,14 +567,6 @@ class GroupedConvLSTMCell(contrib_rnn.RNNCell):
             scope='new_c_%d/act_quant' % k)

         o_act = tf.sigmoid(o)
-        # The quantization range is fixed for the sigmoid to ensure that zero
-        # is exactly representable.
-        o_act = lstm_utils.fixed_quantize_op(
-            o_act,
-            fixed_min=0.0,
-            fixed_max=1.0,
-            is_quantized=self._is_quantized,
-            scope='output_%d/act_quant' % k)

         new_h = new_c_act * o_act
         # The quantization range is fixed since it is input to a concat.
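The three hunks above delete the same pattern: a fixed-range fake quantization pinned to [0, 1] after each sigmoid gate (forget, input, output). Judging only from the removed signature, each deleted block amounted to roughly the following; the tf.fake_quant_with_min_max_args call is an assumption about what lstm_utils.fixed_quantize_op does when is_quantized is set, not something this diff confirms:

# Sketch of the removed pattern, not the library's actual internals.
f_act = tf.sigmoid(f_add)
f_act = tf.fake_quant_with_min_max_args(f_act, min=0.0, max=1.0)

The fixed ranges on the relu6 activation and on concat inputs remain in place, per the surviving comments.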
@@ -120,6 +120,8 @@ def quantizable_separable_conv2d(inputs,
                                  stride=1,
                                  activation_fn=tf.nn.relu6,
                                  normalizer_fn=None,
+                                 weights_initializer=None,
+                                 pointwise_initializer=None,
                                  scope=None):
   """Quantization friendly backward compatible separable conv2d.
@@ -145,6 +147,8 @@ def quantizable_separable_conv2d(inputs,
     activation_fn: Activation function. The default value is a ReLU function.
       Explicitly set it to None to skip it and maintain a linear activation.
     normalizer_fn: Normalization function to use instead of biases.
+    weights_initializer: An initializer for the depthwise weights.
+    pointwise_initializer: An initializer for the pointwise weights.
     scope: Optional scope for variable_scope.

   Returns:
@@ -160,6 +164,8 @@ def quantizable_separable_conv2d(inputs,
         activation_fn=None,
         normalizer_fn=None,
         biases_initializer=None,
+        weights_initializer=weights_initializer,
+        pointwise_initializer=None,
         scope=scope)
     outputs = contrib_layers.bias_add(
         outputs, trainable=True, scope='%s_bias' % scope)
@@ -169,6 +175,7 @@ def quantizable_separable_conv2d(inputs,
         activation_fn=activation_fn,
         stride=stride,
         normalizer_fn=normalizer_fn,
+        weights_initializer=pointwise_initializer,
         scope=scope)
   else:
     outputs = contrib_layers.separable_conv2d(
@@ -179,6 +186,8 @@ def quantizable_separable_conv2d(inputs,
         stride=stride,
         activation_fn=activation_fn,
         normalizer_fn=normalizer_fn,
+        weights_initializer=weights_initializer,
+        pointwise_initializer=pointwise_initializer,
         scope=scope)
   return outputs
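Note the asymmetry in the quantized branch above: the op is built as a depthwise stage followed by a separate pointwise stage, so weights_initializer reaches the depthwise variables and pointwise_initializer reaches the 1x1 convolution that follows. A short usage sketch (shape, stddev values, and scope name are illustrative):

import tensorflow.compat.v1 as tf
from lstm_object_detection.lstm import utils as lstm_utils

inputs = tf.placeholder(tf.float32, [1, 19, 19, 32])
out = lstm_utils.quantizable_separable_conv2d(
    inputs,
    64,        # num_outputs
    [3, 3],    # kernel_size
    is_quantized=True,  # exercises the split depthwise/pointwise path
    weights_initializer=tf.truncated_normal_initializer(stddev=0.09),
    pointwise_initializer=tf.truncated_normal_initializer(stddev=0.03),
    scope='sep_conv_quant')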
@@ -82,7 +82,7 @@ class FakeLSTMFeatureExtractor(
         min_depth=self._min_depth,
         insert_1x1_conv=True,
         image_features=image_features)
-    return feature_maps.values()
+    return list(feature_maps.values())


 class FakeLSTMInterleavedFeatureExtractor(
@@ -141,7 +141,7 @@ class FakeLSTMInterleavedFeatureExtractor(
         min_depth=self._min_depth,
         insert_1x1_conv=True,
         image_features=image_features)
-    return feature_maps.values()
+    return list(feature_maps.values())


 class MockAnchorGenerator2x2(anchor_generator.AnchorGenerator):
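The two test fixes are independent of the initializer change: under Python 3, dict.values() returns a view rather than a list, so downstream code that indexes or slices the returned feature maps would fail without the explicit list(). For example:

feature_maps = {'layer_1': 'a', 'layer_2': 'b'}
values = feature_maps.values()
# values[0] raises TypeError in Python 3: 'dict_values' is not subscriptable.
assert list(values)[0] == 'a'  # wrapping in list() restores indexing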