Commit 4ea11215 authored by Zhenyu Tan, committed by A. Unique TensorFlower

aliasing OnDeviceEmbedding inside tensorflow_models.

PiperOrigin-RevId: 331173006
parent a298202e
@@ -26,7 +26,7 @@ import tensorflow as tf
 from official.modeling import hyperparams
 from official.modeling import tf_utils
-from official.nlp.modeling import layers
+from official.nlp import keras_nlp
 from official.nlp.modeling import networks
@@ -137,10 +137,11 @@ ENCODER_CLS = {
 @gin.configurable
-def build_encoder(config: EncoderConfig,
-                  embedding_layer: Optional[layers.OnDeviceEmbedding] = None,
-                  encoder_cls=None,
-                  bypass_config: bool = False):
+def build_encoder(
+    config: EncoderConfig,
+    embedding_layer: Optional[keras_nlp.layers.OnDeviceEmbedding] = None,
+    encoder_cls=None,
+    bypass_config: bool = False):
   """Instantiate a Transformer encoder network from EncoderConfig.
 
   Args:
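A minimal sketch of the updated call, assuming a prepared EncoderConfig instance named encoder_config and illustrative vocabulary sizes; a pre-built keras_nlp embedding layer can be shared with the encoder through the embedding_layer argument:

from official.nlp import keras_nlp

# Illustrative sizes; any OnDeviceEmbedding instance can be passed here.
shared_embedding = keras_nlp.layers.OnDeviceEmbedding(
    vocab_size=30522, embedding_width=768, name="word_embeddings")
# encoder = build_encoder(config=encoder_config,
#                         embedding_layer=shared_embedding)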
...
@@ -34,9 +34,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
       lookup. Defaults to False (that is, using tf.gather). Setting this option
       to True may improve performance, especially on small vocabulary sizes, but
       will generally require more memory.
-    use_scale: Whether to scale the output embeddings. Defaults to False (that
-      is, not to scale). Setting this option to True will let values in output
-      embeddings multiplied by self._embedding_width ** 0.5.
+    scale_factor: Whether to scale the output embeddings. Defaults to None (that
+      is, not to scale). Setting this option to a float will let values in
+      output embeddings multiplied by scale_factor.
   """
 
   def __init__(self,
@@ -44,7 +44,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
                embedding_width,
                initializer="glorot_uniform",
                use_one_hot=False,
-               use_scale=False,
+               scale_factor=None,
                **kwargs):
     super(OnDeviceEmbedding, self).__init__(**kwargs)
@@ -52,7 +52,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
     self._embedding_width = embedding_width
     self._initializer = initializer
     self._use_one_hot = use_one_hot
-    self._use_scale = use_scale
+    self._scale_factor = scale_factor
 
   def get_config(self):
     config = {
@@ -60,7 +60,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
         "embedding_width": self._embedding_width,
         "initializer": self._initializer,
         "use_one_hot": self._use_one_hot,
-        "use_scale": self._use_scale,
+        "scale_factor": self._scale_factor,
     }
     base_config = super(OnDeviceEmbedding, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
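Because scale_factor now round-trips through get_config, a configured layer can be rebuilt with its scaling intact. A minimal sketch, using illustrative sizes:

from official.nlp import keras_nlp

layer = keras_nlp.layers.OnDeviceEmbedding(
    vocab_size=100, embedding_width=16, scale_factor=4.0)
config = layer.get_config()
assert config["scale_factor"] == 4.0
restored = keras_nlp.layers.OnDeviceEmbedding.from_config(config)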
@@ -87,6 +87,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
         # Work around b/142213824: prefer concat to shape over a Python list.
         tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
     embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
-    if self._use_scale:
-      embeddings *= self._embedding_width**0.5
+    if self._scale_factor:
+      embeddings *= self._scale_factor
     return embeddings
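The old use_scale=True path always multiplied the lookup by embedding_width ** 0.5; callers that want the previous behavior now pass that factor explicitly. A minimal sketch of the migration, reusing the sizes from the test below:

import tensorflow as tf
from official.nlp import keras_nlp

embedding_width = 27
layer = keras_nlp.layers.OnDeviceEmbedding(
    vocab_size=31,
    embedding_width=embedding_width,
    scale_factor=embedding_width ** 0.5)  # previously: use_scale=True
ids = tf.constant([[1, 2, 3]])
scaled = layer(ids)  # shape (1, 3, 27); values multiplied by sqrt(27)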
@@ -192,7 +192,8 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
     vocab_size = 31
     embedding_width = 27
     test_layer = on_device_embedding.OnDeviceEmbedding(
-        vocab_size=vocab_size, embedding_width=embedding_width, use_scale=True)
+        vocab_size=vocab_size, embedding_width=embedding_width,
+        scale_factor=embedding_width**0.5)
     # Create a 2-dimensional input (the first dimension is implicit).
     sequence_length = 23
     input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
...
@@ -142,12 +142,12 @@ class Seq2SeqTransformer(tf.keras.Model):
     self._beam_size = beam_size
     self._alpha = alpha
     self._dtype = dtype
-    self.embedding_lookup = layers.OnDeviceEmbedding(
+    self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
         vocab_size=self._vocab_size,
         embedding_width=self._embedding_width,
         initializer=tf.random_normal_initializer(
             mean=0., stddev=self._embedding_width**-0.5),
-        use_scale=True)
+        scale_factor=self._embedding_width**0.5)
     self.encoder_layer = encoder_layer
     self.decoder_layer = decoder_layer
     self.position_embedding = layers.RelativePositionEmbedding(
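The initializer and the scale factor are deliberately paired: weights are drawn with standard deviation embedding_width ** -0.5 and the lookup is rescaled by embedding_width ** 0.5, so the scaled embeddings come out with roughly unit standard deviation. A quick arithmetic check, with an illustrative width:

embedding_width = 512
init_stddev = embedding_width ** -0.5   # ~0.0442
scale = embedding_width ** 0.5          # ~22.63
print(init_stddev * scale)              # 1.0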
@@ -472,7 +472,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
     self.encoder_layers = []
     for i in range(self.num_layers):
       self.encoder_layers.append(
-          keras_nlp.TransformerEncoderBlock(
+          keras_nlp.layers.TransformerEncoderBlock(
              num_attention_heads=self.num_attention_heads,
              inner_dim=self._intermediate_size,
              inner_activation=self._activation,
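TransformerEncoderBlock is only re-aliased here, not changed. A minimal sketch of constructing it through the new namespace, using only the arguments visible in this hunk (other options left at their defaults, values illustrative):

from official.nlp import keras_nlp

block = keras_nlp.layers.TransformerEncoderBlock(
    num_attention_heads=8, inner_dim=2048, inner_activation="relu")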
...
@@ -141,7 +141,7 @@ class EncoderScaffold(tf.keras.Model):
           shape=(seq_length,), dtype=tf.int32, name='input_type_ids')
       inputs = [word_ids, mask, type_ids]
 
-      self._embedding_layer = layers.OnDeviceEmbedding(
+      self._embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
           vocab_size=embedding_cfg['vocab_size'],
           embedding_width=embedding_cfg['hidden_size'],
           initializer=embedding_cfg['initializer'],
@@ -150,13 +150,13 @@ class EncoderScaffold(tf.keras.Model):
       word_embeddings = self._embedding_layer(word_ids)
 
       # Always uses dynamic slicing for simplicity.
-      self._position_embedding_layer = keras_nlp.PositionEmbedding(
+      self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
           initializer=embedding_cfg['initializer'],
           max_length=embedding_cfg['max_seq_length'],
           name='position_embedding')
       position_embeddings = self._position_embedding_layer(word_embeddings)
-      self._type_embedding_layer = layers.OnDeviceEmbedding(
+      self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['type_vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=embedding_cfg['initializer'],
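For reference, a hedged sketch of the embedding_cfg keys these hunks read; the real EncoderScaffold configuration may require additional entries (dropout rates, etc.), and the values here are only illustrative:

import tensorflow as tf

embedding_cfg = {
    "vocab_size": 30522,
    "type_vocab_size": 2,
    "hidden_size": 768,
    "max_seq_length": 512,
    "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
}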
...
@@ -101,18 +101,18 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
     self.max_sequence_length = max_sequence_length
     self.dropout_rate = dropout_rate
 
-    self.word_embedding = layers.OnDeviceEmbedding(
+    self.word_embedding = keras_nlp.layers.OnDeviceEmbedding(
         self.word_vocab_size,
         self.word_embed_size,
         initializer=initializer,
         name='word_embedding')
-    self.type_embedding = layers.OnDeviceEmbedding(
+    self.type_embedding = keras_nlp.layers.OnDeviceEmbedding(
         self.type_vocab_size,
         self.output_embed_size,
         use_one_hot=True,
         initializer=initializer,
         name='type_embedding')
-    self.pos_embedding = keras_nlp.PositionEmbedding(
+    self.pos_embedding = keras_nlp.layers.PositionEmbedding(
         max_length=max_sequence_length,
         initializer=initializer,
         name='position_embedding')
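After this change all three MobileBERT embedding sub-layers resolve from the same keras_nlp.layers namespace. A minimal sketch with illustrative sizes (not MobileBERT's real defaults):

from official.nlp import keras_nlp

word_emb = keras_nlp.layers.OnDeviceEmbedding(vocab_size=30522,
                                              embedding_width=128)
type_emb = keras_nlp.layers.OnDeviceEmbedding(vocab_size=2,
                                              embedding_width=512,
                                              use_one_hot=True)
pos_emb = keras_nlp.layers.PositionEmbedding(max_length=512)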
...