Commit bd4cb317 authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Prepare for upcoming keras initializer change.

PiperOrigin-RevId: 451485251
parent b1157cf4
......@@ -19,6 +19,7 @@ from typing import Any, Callable, Optional, Union, Tuple
from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
......@@ -138,20 +139,20 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')
......@@ -169,7 +170,7 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')
# The first 999 tokens are special tokens such as [PAD], [CLS], [SEP].
......@@ -204,14 +205,14 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='pooler_transform')
self._config = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment