"vscode:/vscode.git/clone" did not exist on "a9c1e3a31dd7025e3e4dab9003f091424706d770"
Commit a5e4965d authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 449555522
parent d15302ff
......@@ -13,10 +13,9 @@
# limitations under the License.
"""Contains custom quantization layer transforms."""
from typing import Type, Mapping
from typing import Any, Type, Mapping, List, Union, Tuple
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.projects.qat.vision.modeling.layers import nn_blocks as quantized_nn_blocks
from official.projects.qat.vision.modeling.layers import nn_layers as quantized_nn_layers
......@@ -58,17 +57,23 @@ class CustomLayerQuantize(
}
return layer_metadata
def _create_dummy_input_shape(
self, quantized_layer: tf.keras.layers.Layer
) -> Union[List[int], Tuple[Any, Any]]:
dummy_input_shape = [1, 128, 128, 1]
# SegmentationHead layer requires a tuple of 2 tensors.
if isinstance(quantized_layer,
quantized_nn_layers.SegmentationHeadQuantized):
dummy_input_shape = ([1, 1, 1, 1], [1, 1, 1, 1])
return dummy_input_shape
def replacement(self, match_layer: LayerNode) -> LayerNode:
"""See base class."""
bottleneck_layer = match_layer.layer
bottleneck_config = bottleneck_layer['config']
bottleneck_names_and_weights = list(match_layer.names_and_weights)
quantized_layer = self._quantized_layer_class(**bottleneck_config)
dummy_input_shape = [1, 64, 128, 1]
# SegmentationHead layer requires a tuple of 2 tensors.
if isinstance(quantized_layer,
quantized_nn_layers.SegmentationHeadQuantized):
dummy_input_shape = ([1, 1, 1, 1], [1, 1, 1, 1])
dummy_input_shape = self._create_dummy_input_shape(quantized_layer)
quantized_layer.compute_output_shape(dummy_input_shape)
quantized_names_and_weights = zip(
[weight.name for weight in quantized_layer.weights],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment