"docs/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "148a0856ab47d4b4a92894838a4525adaddaaf0b"
Commit fcff6f65 authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

BertEncoder to ingest OnDeviceEmbedding from keras_nlp.

PiperOrigin-RevId: 330767398
parent ad65d85b
......@@ -19,7 +19,6 @@ import tensorflow as tf
from official.modeling import activations
from official.nlp import keras_nlp
from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
......@@ -123,7 +122,7 @@ class BertEncoder(tf.keras.Model):
if embedding_width is None:
embedding_width = hidden_size
if embedding_layer is None:
self._embedding_layer = layers.OnDeviceEmbedding(
self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
......@@ -133,12 +132,12 @@ class BertEncoder(tf.keras.Model):
word_embeddings = self._embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
self._position_embedding_layer = keras_nlp.PositionEmbedding(
self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
initializer=initializer,
max_length=max_sequence_length,
name='position_embedding')
position_embeddings = self._position_embedding_layer(word_embeddings)
self._type_embedding_layer = layers.OnDeviceEmbedding(
self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
......@@ -168,7 +167,7 @@ class BertEncoder(tf.keras.Model):
self._transformer_layers = []
data = embeddings
attention_mask = keras_nlp.SelfAttentionMask()(data, mask)
attention_mask = keras_nlp.layers.SelfAttentionMask()(data, mask)
encoder_outputs = []
for i in range(num_layers):
if i == num_layers - 1 and output_range is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment