Commit bd4cb317 authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Prepare for upcoming keras initializer change.

PiperOrigin-RevId: 451485251
parent b1157cf4
...@@ -19,6 +19,7 @@ from typing import Any, Callable, Optional, Union, Tuple ...@@ -19,6 +19,7 @@ from typing import Any, Callable, Optional, Union, Tuple
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers from official.nlp.modeling import layers
...@@ -138,20 +139,20 @@ class TokenDropBertEncoder(tf.keras.layers.Layer): ...@@ -138,20 +139,20 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
self._embedding_layer = layers.OnDeviceEmbedding( self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size, vocab_size=vocab_size,
embedding_width=embedding_width, embedding_width=embedding_width,
initializer=initializer, initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings') name='word_embeddings')
else: else:
self._embedding_layer = embedding_layer self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding( self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer, initializer=tf_utils.clone_initializer(initializer),
max_length=max_sequence_length, max_length=max_sequence_length,
name='position_embedding') name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding( self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size, vocab_size=type_vocab_size,
embedding_width=embedding_width, embedding_width=embedding_width,
initializer=initializer, initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True, use_one_hot=True,
name='type_embeddings') name='type_embeddings')
...@@ -169,7 +170,7 @@ class TokenDropBertEncoder(tf.keras.layers.Layer): ...@@ -169,7 +170,7 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
'...x,xy->...y', '...x,xy->...y',
output_shape=hidden_size, output_shape=hidden_size,
bias_axes='y', bias_axes='y',
kernel_initializer=initializer, kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection') name='embedding_projection')
# The first 999 tokens are special tokens such as [PAD], [CLS], [SEP]. # The first 999 tokens are special tokens such as [PAD], [CLS], [SEP].
...@@ -204,14 +205,14 @@ class TokenDropBertEncoder(tf.keras.layers.Layer): ...@@ -204,14 +205,14 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
attention_dropout=attention_dropout, attention_dropout=attention_dropout,
norm_first=norm_first, norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None, output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer, kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer/layer_%d' % i) name='transformer/layer_%d' % i)
self._transformer_layers.append(layer) self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense( self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size, units=hidden_size,
activation='tanh', activation='tanh',
kernel_initializer=initializer, kernel_initializer=tf_utils.clone_initializer(initializer),
name='pooler_transform') name='pooler_transform')
self._config = { self._config = {
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment