Commit 62487257 authored by A. Unique TensorFlower
Browse files

Add options to train a quantization friendly model.

PiperOrigin-RevId: 387434804
parent 697da657
...@@ -39,6 +39,23 @@ class NoNorm(tf.keras.layers.Layer): ...@@ -39,6 +39,23 @@ class NoNorm(tf.keras.layers.Layer):
return output return output
@tf.keras.utils.register_keras_serializable(package='Text')
class NoNormClipped(NoNorm):
  """NoNorm variant whose output is bounded for quantization.

  Applies the element-wise affine transform `feature * scale + bias` and then
  clamps every value of the result to the closed interval [-6.0, 6.0].
  """

  def __init__(self, name=None):
    """Builds the layer.

    Args:
      name: Optional layer name, forwarded unchanged to the parent
        constructor.
    """
    super().__init__(name=name)

  def call(self, feature):
    """Returns the affine-transformed `feature`, clipped to [-6.0, 6.0].

    Args:
      feature: Input tensor to be scaled, shifted and clipped.

    Returns:
      A tensor holding `feature * self.scale + self.bias` with every value
      limited to the range [-6.0, 6.0].
    """
    # `self.scale` and `self.bias` are not created here; presumably they are
    # set up by the parent NoNorm layer -- confirm against its definition.
    return tf.clip_by_value(
        feature * self.scale + self.bias,
        clip_value_min=-6.0,
        clip_value_max=6.0)
def _get_norm_layer(normalization_type='no_norm', name=None): def _get_norm_layer(normalization_type='no_norm', name=None):
"""Get normlization layer. """Get normlization layer.
...@@ -52,6 +69,8 @@ def _get_norm_layer(normalization_type='no_norm', name=None): ...@@ -52,6 +69,8 @@ def _get_norm_layer(normalization_type='no_norm', name=None):
""" """
if normalization_type == 'no_norm': if normalization_type == 'no_norm':
layer = NoNorm(name=name) layer = NoNorm(name=name)
elif normalization_type == 'no_norm_clipped':
layer = NoNormClipped(name=name)
elif normalization_type == 'layer_norm': elif normalization_type == 'layer_norm':
layer = tf.keras.layers.LayerNormalization( layer = tf.keras.layers.LayerNormalization(
name=name, name=name,
......
...@@ -33,6 +33,22 @@ def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0): ...@@ -33,6 +33,22 @@ def generate_fake_input(batch_size=1, seq_len=5, vocab_size=10000, seed=0):
return fake_input return fake_input
class EdgeTPUNoNormTest(tf.test.TestCase):
  """Checks the clipped NoNorm layer used for quantization-friendly models."""

  def test_no_norm(self):
    """Output keeps the input shape and stays within [-6.0, 6.0]."""
    input_shape = [2, 3, 4]
    # Inputs are drawn from [-8, 8), i.e. wider than the clipping window, so
    # the clip has something to do (assuming the layer's default scale/bias
    # leave the values roughly unchanged -- TODO confirm).
    random_feature = tf.random.uniform(
        input_shape, minval=-8, maxval=8, dtype=tf.float32)
    clipped_layer = mobile_bert_layers.NoNormClipped()

    result = clipped_layer(random_feature)

    # The layer is element-wise, so the shape must be preserved.
    self.assertListEqual(result.shape.as_list(), [2, 3, 4])

    # Neither extreme of the output may leave the clipping interval.
    self.assertGreaterEqual(6.0, tf.reduce_max(result))
    self.assertLessEqual(-6.0, tf.reduce_min(result))
class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase): class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
def test_embedding_layer_with_token_type(self): def test_embedding_layer_with_token_type(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment