Commit f5e6e291 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 289743636
parent 242ad38d
@@ -43,10 +43,11 @@ class AlbertTransformerEncoder(network.Network):
   Attributes:
     vocab_size: The size of the token vocabulary.
-    embedding_width: The width of the word embeddings. Embedding parameters will
-      be factorized into two matrices in the shape of ['vocab_size',
-      'embedding_width'] and ['embedding_width', 'hidden_size']
-      ('embedding_width' is usually much smaller than 'hidden_size').
+    embedding_width: The width of the word embeddings. If the embedding width
+      is not equal to hidden size, embedding parameters will be factorized into
+      two matrices in the shape of ['vocab_size', 'embedding_width'] and
+      ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
+      smaller than 'hidden_size').
     hidden_size: The size of the transformer hidden layers.
     num_layers: The number of transformer layers.
     num_attention_heads: The number of attention heads for each transformer. The
@@ -149,9 +150,13 @@ class AlbertTransformerEncoder(network.Network):
     embeddings = (
         tf.keras.layers.Dropout(rate=dropout_rate,
                                 dtype=tf.float32)(embeddings))
     # The width of final 'embedding' should be always 'hidden_size'.
+    # We project the 'embedding' output to 'hidden_size' if it is not already
+    # 'hidden_size'.
     if embedding_width != hidden_size:
       embeddings = layers.DenseEinsum(
-          output_shape=hidden_size, name='embedding_projection')(
+          output_shape=hidden_size,
+          kernel_initializer=initializer,
+          name='embedding_projection')(
               embeddings)
     if float_dtype == 'float16':
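
The docstring change makes the factorization conditional on `embedding_width != hidden_size`. As a rough illustration of why that matters, the sketch below counts word-embedding weights with and without the projection. The helper name `embedding_param_count` and the example sizes are hypothetical and not taken from this commit; only the kernel of the projection is counted, and any bias term is ignored.

```python
def embedding_param_count(vocab_size, embedding_width, hidden_size):
    """Count word-embedding weights, including the optional projection kernel.

    Hypothetical helper for illustration only; it is not part of the
    AlbertTransformerEncoder code shown in the diff.
    """
    # Embedding table of shape [vocab_size, embedding_width].
    params = vocab_size * embedding_width
    # Mirrors the `if embedding_width != hidden_size:` branch in the diff:
    # the projection kernel [embedding_width, hidden_size] exists only when
    # the two widths differ.
    if embedding_width != hidden_size:
        params += embedding_width * hidden_size
    return params


# Example sizes are assumptions chosen for illustration.
factorized = embedding_param_count(30000, 128, 768)
unfactorized = embedding_param_count(30000, 768, 768)
print(factorized, unfactorized)
```

When the two widths are equal, no projection is added and the count reduces to the plain `vocab_size * hidden_size` table, which is the case the updated docstring now describes.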