Commit b3377b09 authored by A. Unique TensorFlower
Browse files

In BERT's export to TF Hub, fix shape propagation for seq_length.

PiperOrigin-RevId: 307425903
parent 0b0ca66b
......@@ -32,9 +32,10 @@ class ExportTfhubTest(tf.test.TestCase):
def test_export_tfhub(self):
# Exports a savedmodel for TF-Hub
hidden_size = 16
bert_config = configs.BertConfig(
vocab_size=100,
hidden_size=16,
hidden_size=hidden_size,
intermediate_size=32,
max_position_embeddings=128,
num_attention_heads=2,
......@@ -67,7 +68,8 @@ class ExportTfhubTest(tf.test.TestCase):
hub_layer.trainable_weights):
self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
dummy_ids = np.zeros((2, 10), dtype=np.int32)
seq_length = 10
dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
......@@ -75,13 +77,23 @@ class ExportTfhubTest(tf.test.TestCase):
# while the outputs of encoder is in reversed order, i.e.,
# "sequence_output" and "pooled_output".
encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
self.assertEqual(hub_outputs[0].shape, (2, 16))
self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
self.assertEqual(hub_outputs[0].shape, (2, hidden_size))
self.assertEqual(hub_outputs[1].shape, (2, seq_length, hidden_size))
for source_output, hub_output, encoder_output in zip(
source_outputs, hub_outputs, encoder_outputs):
self.assertAllClose(source_output.numpy(), hub_output.numpy())
self.assertAllClose(source_output.numpy(), encoder_output.numpy())
# Test propagation of seq_length in shape inference.
input_word_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_mask = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
input_type_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
pooled_output, sequence_output = hub_layer(
[input_word_ids, input_mask, input_type_ids])
self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
self.assertEqual(sequence_output.shape.as_list(),
[None, seq_length, hidden_size])
if __name__ == "__main__":
tf.test.main()
......@@ -21,8 +21,6 @@ from __future__ import print_function
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package="Text")
class OnDeviceEmbedding(tf.keras.layers.Layer):
......@@ -78,8 +76,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
super(OnDeviceEmbedding, self).build(input_shape)
def call(self, inputs):
input_shape = tf_utils.get_shape_list(inputs, expected_rank=2)
input_shape.append(self._embedding_width)
flat_inputs = tf.reshape(inputs, [-1])
if self._use_one_hot:
one_hot_data = tf.one_hot(
......@@ -87,6 +83,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
embeddings = tf.matmul(one_hot_data, self.embeddings)
else:
embeddings = tf.gather(self.embeddings, flat_inputs)
embeddings = tf.reshape(embeddings, input_shape)
embeddings = tf.reshape(
embeddings,
# Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
return embeddings
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment