Commit fd270433 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 346088379
parent 8ffa9448
@@ -35,7 +35,8 @@ class PackedSequenceEmbedding(tf.keras.Model):
   Arguments:
     vocab_size: The size of the token vocabulary.
     type_vocab_size: The size of the type vocabulary.
-    hidden_size: The hidden size for this encoder.
+    embedding_width: Width of token embeddings.
+    hidden_size: The output size for this encoder.
     max_seq_length: The maximum sequence length for this encoder.
     initializer: The initializer for the embedding portion of this encoder.
     dropout_rate: The dropout rate to apply before the encoding layers.
@@ -52,6 +53,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
   def __init__(self,
                vocab_size,
                type_vocab_size,
+               embedding_width,
                hidden_size,
                max_seq_length,
                initializer,
@@ -63,6 +65,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
     config_dict = {
         'vocab_size': vocab_size,
         'type_vocab_size': type_vocab_size,
+        'embedding_width': embedding_width,
         'hidden_size': hidden_size,
         'max_seq_length': max_seq_length,
         'initializer': tf.keras.initializers.serialize(initializer),
@@ -96,7 +99,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
     embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=vocab_size,
-        embedding_width=hidden_size,
+        embedding_width=embedding_width,
         initializer=initializer,
         name='word_embeddings')
     word_embeddings = embedding_layer(word_ids)
@@ -113,7 +116,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
     type_embeddings = (
         layers.OnDeviceEmbedding(
             vocab_size=type_vocab_size,
-            embedding_width=hidden_size,
+            embedding_width=embedding_width,
             initializer=initializer,
             use_one_hot=True,
             name='type_embeddings')(type_ids))
@@ -127,6 +130,15 @@ class PackedSequenceEmbedding(tf.keras.Model):
         rate=dropout_rate, dtype=tf.float32)(
             embeddings)
+    if embedding_width != hidden_size:
+      embeddings = tf.keras.layers.experimental.EinsumDense(
+          '...x,xy->...y',
+          output_shape=hidden_size,
+          bias_axes=None,
+          kernel_initializer=initializer,
+          name='embedding_projection')(
+              embeddings)
     attention_mask = layers.SelfAttentionMask()([embeddings, mask])
     if sub_seq_mask is not None:
       attention_mask = tf.keras.layers.Lambda(
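For context: the change above factorizes the embedding path. Word and type tokens are now embedded at embedding_width, and a bias-free EinsumDense projection maps them up to hidden_size only when the two sizes differ; when embedding_width equals hidden_size the projection is skipped, so existing configurations keep their previous behavior. Below is a minimal standalone sketch of that projection step, using illustrative shapes and values rather than anything from this repository.

import tensorflow as tf

# Illustrative sizes; the real values come from the encoder config.
batch_size, seq_length = 2, 8
embedding_width, hidden_size = 16, 32

# Token embeddings looked up at the narrower embedding width.
embeddings = tf.random.normal([batch_size, seq_length, embedding_width])

# Same pattern the diff adds: '...x,xy->...y' projects the last axis
# from embedding_width to hidden_size, with no bias term.
projection = tf.keras.layers.experimental.EinsumDense(
    '...x,xy->...y',
    output_shape=hidden_size,
    bias_axes=None,
    kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    name='embedding_projection')

print(projection(embeddings).shape)  # (2, 8, 32)

The corresponding test file is updated below to cover a configuration where the two sizes differ.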
@@ -45,10 +45,12 @@ class PackedSequenceEmbeddingTest(tf.test.TestCase, parameterized.TestCase):
     vocab_size = 100
     max_position_embeddings = 32
     type_vocab_size = 2
+    embedding_width = 16
     hidden_size = 32
     embedding_cfg = dict(
         vocab_size=vocab_size,
         type_vocab_size=2,
+        embedding_width=embedding_width,
         hidden_size=hidden_size,
         max_seq_length=max_position_embeddings,
         initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
@@ -103,6 +105,7 @@ class PackedSequenceEmbeddingTest(tf.test.TestCase, parameterized.TestCase):
     embedding_cfg = dict(
         vocab_size=100,
         type_vocab_size=2,
+        embedding_width=64,
         hidden_size=64,
         max_seq_length=32,
         initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
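The updated tests cover one configuration where embedding_width (16) is smaller than hidden_size (32) and one where the two match (64). As a back-of-the-envelope illustration of why decoupling the two matters, the following sketch compares the word-embedding parameter count with and without the factorization, using the first test's numbers; it is for illustration only, not code from the repository.

# Sizes from the first test configuration above.
vocab_size = 100
embedding_width = 16
hidden_size = 32

# Direct lookup at hidden_size (the behavior before this change).
direct = vocab_size * hidden_size                                           # 3200

# Lookup at embedding_width plus the 'embedding_projection' kernel.
factorized = vocab_size * embedding_width + embedding_width * hidden_size   # 2112

print(direct, factorized)

The saving is small at this 100-token test vocabulary, but it grows roughly linearly with vocab_size, which is where the factorized embedding pays off.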