"...git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "85a258235fe64d2d6c2b415bfe99de07d930d790"
Commit f5e6e291 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 289743636
parent 242ad38d
@@ -43,10 +43,11 @@ class AlbertTransformerEncoder(network.Network):
   Attributes:
     vocab_size: The size of the token vocabulary.
-    embedding_width: The width of the word embeddings. Embedding parameters will
-      be factorized into two matrices in the shape of ['vocab_size',
-      'embedding_width'] and ['embedding_width', 'hidden_size']
-      ('embedding_width' is usually much smaller than 'hidden_size').
+    embedding_width: The width of the word embeddings. If the embedding width
+      is not equal to hidden size, embedding parameters will be factorized into
+      two matrices in the shape of ['vocab_size', 'embedding_width'] and
+      ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
+      smaller than 'hidden_size').
     hidden_size: The size of the transformer hidden layers.
     num_layers: The number of transformer layers.
     num_attention_heads: The number of attention heads for each transformer. The
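For a concrete sense of what the factorization described in the updated docstring buys, here is a minimal arithmetic sketch with assumed example sizes (vocab_size=30000, embedding_width=128, hidden_size=768 are illustrative values, not values taken from this commit):

```python
# Illustrative only: parameter counts for an unfactorized vs. factorized
# embedding table. All sizes below are assumed example values.
vocab_size = 30000
embedding_width = 128
hidden_size = 768

# Unfactorized: a single [vocab_size, hidden_size] embedding matrix.
unfactorized = vocab_size * hidden_size  # 23,040,000 parameters

# Factorized: a [vocab_size, embedding_width] table plus the
# [embedding_width, hidden_size] projection, which this change applies
# only when embedding_width != hidden_size.
factorized = vocab_size * embedding_width + embedding_width * hidden_size
# 3,840,000 + 98,304 = 3,938,304 parameters

print(unfactorized, factorized)
```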
@@ -149,10 +150,14 @@ class AlbertTransformerEncoder(network.Network):
     embeddings = (
         tf.keras.layers.Dropout(rate=dropout_rate,
                                 dtype=tf.float32)(embeddings))
-    # The width of final 'embedding' should be always 'hidden_size'.
-    embeddings = layers.DenseEinsum(
-        output_shape=hidden_size, name='embedding_projection')(
-            embeddings)
+    # We project the 'embedding' output to 'hidden_size' if it is not already
+    # 'hidden_size'.
+    if embedding_width != hidden_size:
+      embeddings = layers.DenseEinsum(
+          output_shape=hidden_size,
+          kernel_initializer=initializer,
+          name='embedding_projection')(
+              embeddings)
     if float_dtype == 'float16':
       embeddings = tf.cast(embeddings, tf.float16)
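The new conditional projection can be sketched in isolation as follows. This is a minimal stand-alone approximation with made-up sizes that uses tf.keras.layers.Dense in place of the modeling library's layers.DenseEinsum, so it illustrates the control flow of the change rather than reproducing the encoder:

```python
import tensorflow as tf

# Assumed example sizes; not taken from the commit.
vocab_size, embedding_width, hidden_size = 30000, 128, 768
initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02)

word_ids = tf.keras.Input(shape=(None,), dtype=tf.int32)

# Narrow [vocab_size, embedding_width] embedding lookup.
embeddings = tf.keras.layers.Embedding(
    vocab_size, embedding_width,
    embeddings_initializer=initializer)(word_ids)

# Project to 'hidden_size' only when the widths differ, mirroring the
# `if embedding_width != hidden_size:` guard added in this change.
# tf.keras.layers.Dense stands in for layers.DenseEinsum here.
if embedding_width != hidden_size:
  embeddings = tf.keras.layers.Dense(
      hidden_size,
      kernel_initializer=initializer,
      name='embedding_projection')(embeddings)

model = tf.keras.Model(inputs=word_ids, outputs=embeddings)
model.summary()
```

Because the guard runs at graph-construction time, no projection layer (and no extra parameters) is created at all when embedding_width already equals hidden_size.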