Commit fd270433 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 346088379
parent 8ffa9448
@@ -35,7 +35,8 @@ class PackedSequenceEmbedding(tf.keras.Model):
   Arguments:
     vocab_size: The size of the token vocabulary.
     type_vocab_size: The size of the type vocabulary.
-    hidden_size: The hidden size for this encoder.
+    embedding_width: Width of token embeddings.
+    hidden_size: The output size for this encoder.
     max_seq_length: The maximum sequence length for this encoder.
     initializer: The initializer for the embedding portion of this encoder.
     dropout_rate: The dropout rate to apply before the encoding layers.
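Splitting embedding_width off from hidden_size lets the token embedding table stay narrower than the encoder width, with a learned projection bridging the two (an embedding factorization in the ALBERT style). A minimal shape sketch with made-up sizes; none of the names below come from the change itself:

import tensorflow as tf

vocab_size, embedding_width, hidden_size = 100, 16, 32
ids = tf.constant([[1, 2, 3, 4]])                          # [batch=1, seq=4] token ids
table = tf.random.normal([vocab_size, embedding_width])    # narrow embedding table
emb = tf.nn.embedding_lookup(table, ids)                   # [1, 4, 16]
kernel = tf.random.normal([embedding_width, hidden_size])  # projection kernel
projected = tf.einsum('...x,xy->...y', emb, kernel)        # [1, 4, 32], encoder width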
@@ -52,6 +53,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
   def __init__(self,
                vocab_size,
                type_vocab_size,
+               embedding_width,
                hidden_size,
                max_seq_length,
                initializer,
@@ -63,6 +65,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
     config_dict = {
         'vocab_size': vocab_size,
         'type_vocab_size': type_vocab_size,
+        'embedding_width': embedding_width,
         'hidden_size': hidden_size,
         'max_seq_length': max_seq_length,
         'initializer': tf.keras.initializers.serialize(initializer),
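Because config_dict stores a serialized initializer rather than the object itself, the config stays a plain, JSON-friendly dict and the initializer can be rebuilt later. A small sketch of that round-trip using standard Keras utilities (not code from this change):

import tensorflow as tf

init = tf.keras.initializers.TruncatedNormal(stddev=0.02)
serialized = tf.keras.initializers.serialize(init)         # plain dict, safe to keep in a config
restored = tf.keras.initializers.deserialize(serialized)   # equivalent initializer object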
@@ -96,7 +99,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
     embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=vocab_size,
-        embedding_width=hidden_size,
+        embedding_width=embedding_width,
         initializer=initializer,
         name='word_embeddings')
     word_embeddings = embedding_layer(word_ids)
@@ -113,7 +116,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
     type_embeddings = (
         layers.OnDeviceEmbedding(
             vocab_size=type_vocab_size,
-            embedding_width=hidden_size,
+            embedding_width=embedding_width,
             initializer=initializer,
             use_one_hot=True,
             name='type_embeddings')(type_ids))
@@ -127,6 +130,15 @@ class PackedSequenceEmbedding(tf.keras.Model):
         rate=dropout_rate, dtype=tf.float32)(
             embeddings)
+    if embedding_width != hidden_size:
+      embeddings = tf.keras.layers.experimental.EinsumDense(
+          '...x,xy->...y',
+          output_shape=hidden_size,
+          bias_axes=None,
+          kernel_initializer=initializer,
+          name='embedding_projection')(
+              embeddings)
     attention_mask = layers.SelfAttentionMask()([embeddings, mask])
     if sub_seq_mask is not None:
       attention_mask = tf.keras.layers.Lambda(
......
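The new embedding_projection layer applies only when embedding_width differs from hidden_size; it contracts the last axis of the embeddings against a [embedding_width, hidden_size] kernel. A hedged sketch of what that layer computes, with assumed batch and sequence sizes (newer TensorFlow releases also expose the same layer as tf.keras.layers.EinsumDense):

import tensorflow as tf

proj = tf.keras.layers.experimental.EinsumDense(
    '...x,xy->...y', output_shape=32, bias_axes=None, name='embedding_projection')
x = tf.zeros([2, 8, 16])   # [batch, seq, embedding_width]
y = proj(x)                # [2, 8, 32]: last axis projected up to hidden_size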
@@ -45,10 +45,12 @@ class PackedSequenceEmbeddingTest(tf.test.TestCase, parameterized.TestCase):
     vocab_size = 100
     max_position_embeddings = 32
     type_vocab_size = 2
+    embedding_width = 16
     hidden_size = 32
     embedding_cfg = dict(
         vocab_size=vocab_size,
         type_vocab_size=2,
+        embedding_width=embedding_width,
         hidden_size=hidden_size,
         max_seq_length=max_position_embeddings,
         initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
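With the sizes used in this test (vocab_size=100, embedding_width=16, hidden_size=32), the factorized table plus projection holds fewer parameters than a single full-width table, and the savings grow with vocabulary size. A quick arithmetic check in plain Python (not part of the test):

vocab_size, embedding_width, hidden_size = 100, 16, 32
factorized = vocab_size * embedding_width + embedding_width * hidden_size  # 1600 + 512 = 2112
full_width = vocab_size * hidden_size                                      # 3200
assert factorized < full_width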
@@ -103,6 +105,7 @@ class PackedSequenceEmbeddingTest(tf.test.TestCase, parameterized.TestCase):
     embedding_cfg = dict(
         vocab_size=100,
         type_vocab_size=2,
+        embedding_width=64,
         hidden_size=64,
         max_seq_length=32,
         initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
......