Commit b1eddf4f authored by A. Unique TensorFlower

Makes token embedding projection consistent between Albert and BERT

PiperOrigin-RevId: 312751112
parent 99b5438a
@@ -64,7 +64,8 @@ def _create_bert_model(cfg):
       sequence_length=cfg.max_position_embeddings,
       type_vocab_size=cfg.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=cfg.initializer_range))
+          stddev=cfg.initializer_range),
+      embedding_width=cfg.embedding_size)
   return bert_encoder
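The first file's change is at the encoder's call site: _create_bert_model now forwards cfg.embedding_size as embedding_width, so whether the token embeddings get projected is driven by the config rather than implicitly tied to hidden_size. Below is a minimal sketch of constructing an encoder the same way; it assumes the TransformerEncoder exported from official.nlp.modeling.networks at this revision, and all sizes are illustrative rather than taken from a real config.

import tensorflow as tf
from official.nlp.modeling import networks

# Illustrative, ALBERT-style sizes: embedding_width < hidden_size, so the
# encoder factorizes the token embeddings and projects them up internally.
encoder = networks.TransformerEncoder(
    vocab_size=30000,
    hidden_size=768,
    num_layers=12,
    num_attention_heads=12,
    sequence_length=512,
    type_vocab_size=2,
    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    embedding_width=128)

With embedding_width equal to hidden_size (or left unset), the projection is skipped and the encoder behaves exactly as BERT did before this commit.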
@@ -146,6 +146,15 @@ class TransformerEncoder(tf.keras.Model):
     embeddings = tf.keras.layers.Add()(
         [word_embeddings, position_embeddings, type_embeddings])
+    embeddings = (
+        tf.keras.layers.LayerNormalization(
+            name='embeddings/layer_norm',
+            axis=-1,
+            epsilon=1e-12,
+            dtype=tf.float32)(embeddings))
+    embeddings = (
+        tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
     # We project the 'embedding' output to 'hidden_size' if it is not already
     # 'hidden_size'.
     if embedding_width != hidden_size:
@@ -156,14 +165,6 @@ class TransformerEncoder(tf.keras.Model):
           kernel_initializer=initializer,
           name='embedding_projection')
       embeddings = self._embedding_projection(embeddings)
-    embeddings = (
-        tf.keras.layers.LayerNormalization(
-            name='embeddings/layer_norm',
-            axis=-1,
-            epsilon=1e-12,
-            dtype=tf.float32)(embeddings))
-    embeddings = (
-        tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
     self._transformer_layers = []
     data = embeddings
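The second file reorders the embedding post-processing inside TransformerEncoder: layer normalization and dropout now run on the embedding_width-sized tensor before the optional embedding_projection, rather than on the hidden_size-sized tensor after it. That is the order ALBERT uses, which is what makes the two models' projections consistent. A self-contained sketch of the resulting embedding path, written with plain Keras layers; the sizes are hypothetical, position and type embeddings are omitted for brevity, and Dense merely stands in for the repo's embedding_projection layer:

import tensorflow as tf

vocab_size, embedding_width, hidden_size, seq_length = 1000, 128, 256, 16
dropout_rate = 0.1

word_ids = tf.keras.Input(shape=(seq_length,), dtype=tf.int32)
embeddings = tf.keras.layers.Embedding(vocab_size, embedding_width)(word_ids)

# As of this commit: normalize and apply dropout at embedding_width first...
embeddings = tf.keras.layers.LayerNormalization(
    axis=-1, epsilon=1e-12, name='embeddings/layer_norm')(embeddings)
embeddings = tf.keras.layers.Dropout(rate=dropout_rate)(embeddings)

# ...and only then project up to hidden_size, and only when the widths differ.
if embedding_width != hidden_size:
  embeddings = tf.keras.layers.Dense(
      hidden_size, name='embedding_projection')(embeddings)

model = tf.keras.Model(inputs=word_ids, outputs=embeddings)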