Commit 57253ebc authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Modify the unit test of encoder scaffold to demonstrate that a subclass...

Modify the unit test of encoder scaffold to demonstrate that a subclass embedding network can be used.

PiperOrigin-RevId: 322083775
parent 57c08a08
......@@ -323,6 +323,28 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
self.assertAllEqual(network.get_config(), new_network.get_config())
class Embeddings(tf.keras.Model):
def __init__(self, vocab_size, hidden_size):
super().__init__()
self.inputs = [
tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_word_ids"),
tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name="input_mask")
]
self.attention_mask = layers.SelfAttentionMask()
self.embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=hidden_size,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
name="word_embeddings")
def call(self, inputs):
word_ids, mask = inputs
word_embeddings = self.embedding_layer(word_ids)
return word_embeddings, self.attention_mask([word_embeddings, mask])
@keras_parameterized.run_all_keras_modes
class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
......@@ -334,20 +356,7 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
# Build an embedding network to swap in for the default network. This one
# will have 2 inputs (mask and word_ids) instead of 3, and won't use
# positional embeddings.
word_ids = tf.keras.layers.Input(
shape=(sequence_length,), dtype=tf.int32, name="input_word_ids")
mask = tf.keras.layers.Input(
shape=(sequence_length,), dtype=tf.int32, name="input_mask")
embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=hidden_size,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
name="word_embeddings")
word_embeddings = embedding_layer(word_ids)
attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])
network = tf.keras.Model([word_ids, mask],
[word_embeddings, attention_mask])
network = Embeddings(vocab_size, hidden_size)
hidden_cfg = {
"num_attention_heads":
......@@ -371,8 +380,7 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=0.02),
hidden_cfg=hidden_cfg,
embedding_cls=network,
embedding_data=embedding_layer.embeddings)
embedding_cls=network)
# Create the inputs (note that the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
......@@ -390,11 +398,6 @@ class EncoderScaffoldEmbeddingNetworkTest(keras_parameterized.TestCase):
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
_ = model.predict([word_id_data, mask_data])
# Test that we can get the embedding data that we passed to the object. This
# is necessary to support standard language model training.
self.assertIs(embedding_layer.embeddings,
test_network.get_embedding_table())
def test_serialize_deserialize(self):
hidden_size = 32
sequence_length = 21
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment