Commit 196f09ae authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 361482425
parent 3e45d52f
@@ -102,6 +102,7 @@ class MobileBertEncoderConfig(hyperparams.Config):
   num_feedforward_networks: int = 1
   normalization_type: str = "layer_norm"
   classifier_activation: bool = True
+  input_mask_dtype: str = "int32"


 @dataclasses.dataclass
@@ -260,7 +261,8 @@ def build_encoder(config: EncoderConfig,
         key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
         num_feedforward_networks=encoder_cfg.num_feedforward_networks,
         normalization_type=encoder_cfg.normalization_type,
-        classifier_activation=encoder_cfg.classifier_activation)
+        classifier_activation=encoder_cfg.classifier_activation,
+        input_mask_dtype=encoder_cfg.input_mask_dtype)
   if encoder_type == "albert":
     return encoder_cls(
......
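As a usage sketch (not part of this commit; the module path `official.nlp.configs.encoders` and the `mobilebert` oneof field name are assumptions based on the class names above), the new config field could be threaded through `build_encoder` like this:

from official.nlp.configs import encoders

# Assumed field name: `mobilebert` oneof on EncoderConfig. Request a float32
# input mask so the exported graph avoids a Cast op.
encoder_cfg = encoders.MobileBertEncoderConfig(input_mask_dtype="float32")
config = encoders.EncoderConfig(type="mobilebert", mobilebert=encoder_cfg)
encoder = encoders.build_encoder(config)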
@@ -113,7 +113,6 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
     self.type_embedding = keras_nlp.layers.OnDeviceEmbedding(
         self.type_vocab_size,
         self.output_embed_size,
-        use_one_hot=True,
         initializer=initializer,
         name='type_embedding')
     self.pos_embedding = keras_nlp.layers.PositionEmbedding(
......
@@ -43,6 +43,7 @@ class MobileBERTEncoder(tf.keras.Model):
                num_feedforward_networks=4,
                normalization_type='no_norm',
                classifier_activation=False,
+               input_mask_dtype='int32',
                **kwargs):
     """Class initialization.
@@ -76,6 +77,11 @@ class MobileBERTEncoder(tf.keras.Model):
         MobileBERT paper. 'layer_norm' is used for the teacher model.
       classifier_activation: If using the tanh activation for the final
         representation of the [CLS] token in fine-tuning.
+      input_mask_dtype: The dtype of the `input_mask` tensor, which is one of
+        the input tensors of this encoder. Defaults to `'int32'`. To use
+        `tf.lite` quantization, which does not support the `Cast` op, set
+        this argument to `'float32'` and feed the `input_mask` tensor with
+        float32 values to avoid a `tf.cast` in the computation.
       **kwargs: Other keyword arguments.
     """
     self._self_setattr_tracking = False
@@ -115,11 +121,14 @@ class MobileBERTEncoder(tf.keras.Model):
     input_ids = tf.keras.layers.Input(
         shape=(None,), dtype=tf.int32, name='input_word_ids')
     input_mask = tf.keras.layers.Input(
-        shape=(None,), dtype=tf.int32, name='input_mask')
+        shape=(None,), dtype=input_mask_dtype, name='input_mask')
     type_ids = tf.keras.layers.Input(
         shape=(None,), dtype=tf.int32, name='input_type_ids')
     self.inputs = [input_ids, input_mask, type_ids]
-    attention_mask = keras_nlp.layers.SelfAttentionMask()(input_ids, input_mask)
+    # The dtype of `attention_mask` will be the same as the dtype of
+    # `input_mask`.
+    attention_mask = keras_nlp.layers.SelfAttentionMask()(input_mask,
+                                                          input_mask)
     # build the computation graph
     all_layer_outputs = []
......
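A minimal invocation sketch, assuming the encoder lives at `official.nlp.modeling.networks.mobile_bert_encoder` (the test file below imports it as `mobile_bert_encoder`): with `input_mask_dtype='float32'` the mask is fed as float32 values, so the graph contains no `Cast` op, which matters for `tf.lite` quantization.

import numpy as np

# Assumed module path in the TF Model Garden.
from official.nlp.modeling.networks import mobile_bert_encoder

encoder = mobile_bert_encoder.MobileBERTEncoder(
    word_vocab_size=100, input_mask_dtype='float32')

word_ids = np.random.randint(0, 100, size=(2, 16)).astype(np.int32)
mask = np.ones((2, 16), dtype=np.float32)  # float32 mask, no tf.cast needed
type_ids = np.zeros((2, 16), dtype=np.int32)
outputs = encoder([word_ids, mask, type_ids])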
@@ -89,7 +89,8 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
     self.assertIsInstance(all_layer_output, list)
     self.assertLen(all_layer_output, num_blocks + 1)

-  def test_mobilebert_encoder_invocation(self):
+  @parameterized.parameters('int32', 'float32')
+  def test_mobilebert_encoder_invocation(self, input_mask_dtype):
     vocab_size = 100
     hidden_size = 32
     sequence_length = 16
@@ -97,10 +98,11 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
     test_network = mobile_bert_encoder.MobileBERTEncoder(
         word_vocab_size=vocab_size,
         hidden_size=hidden_size,
-        num_blocks=num_blocks)
+        num_blocks=num_blocks,
+        input_mask_dtype=input_mask_dtype)
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    mask = tf.keras.Input(shape=(sequence_length,), dtype=input_mask_dtype)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     outputs = test_network([word_ids, mask, type_ids])
     model = tf.keras.Model([word_ids, mask, type_ids], outputs)
......