Commit a5e4965d authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 449555522
parent d15302ff
...@@ -13,10 +13,9 @@ ...@@ -13,10 +13,9 @@
# limitations under the License. # limitations under the License.
"""Contains custom quantization layer transforms.""" """Contains custom quantization layer transforms."""
from typing import Type, Mapping from typing import Any, Type, Mapping, List, Union, Tuple
import tensorflow as tf import tensorflow as tf
import tensorflow_model_optimization as tfmot import tensorflow_model_optimization as tfmot
from official.projects.qat.vision.modeling.layers import nn_blocks as quantized_nn_blocks from official.projects.qat.vision.modeling.layers import nn_blocks as quantized_nn_blocks
from official.projects.qat.vision.modeling.layers import nn_layers as quantized_nn_layers from official.projects.qat.vision.modeling.layers import nn_layers as quantized_nn_layers
...@@ -58,17 +57,23 @@ class CustomLayerQuantize( ...@@ -58,17 +57,23 @@ class CustomLayerQuantize(
} }
return layer_metadata return layer_metadata
def _create_dummy_input_shape(
self, quantized_layer: tf.keras.layers.Layer
) -> Union[List[int], Tuple[Any, Any]]:
dummy_input_shape = [1, 128, 128, 1]
# SegmentationHead layer requires a tuple of 2 tensors.
if isinstance(quantized_layer,
quantized_nn_layers.SegmentationHeadQuantized):
dummy_input_shape = ([1, 1, 1, 1], [1, 1, 1, 1])
return dummy_input_shape
def replacement(self, match_layer: LayerNode) -> LayerNode: def replacement(self, match_layer: LayerNode) -> LayerNode:
"""See base class.""" """See base class."""
bottleneck_layer = match_layer.layer bottleneck_layer = match_layer.layer
bottleneck_config = bottleneck_layer['config'] bottleneck_config = bottleneck_layer['config']
bottleneck_names_and_weights = list(match_layer.names_and_weights) bottleneck_names_and_weights = list(match_layer.names_and_weights)
quantized_layer = self._quantized_layer_class(**bottleneck_config) quantized_layer = self._quantized_layer_class(**bottleneck_config)
dummy_input_shape = [1, 64, 128, 1] dummy_input_shape = self._create_dummy_input_shape(quantized_layer)
# SegmentationHead layer requires a tuple of 2 tensors.
if isinstance(quantized_layer,
quantized_nn_layers.SegmentationHeadQuantized):
dummy_input_shape = ([1, 1, 1, 1], [1, 1, 1, 1])
quantized_layer.compute_output_shape(dummy_input_shape) quantized_layer.compute_output_shape(dummy_input_shape)
quantized_names_and_weights = zip( quantized_names_and_weights = zip(
[weight.name for weight in quantized_layer.weights], [weight.name for weight in quantized_layer.weights],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment