"sgl-router/src/git@developer.sourcefind.cn:change/sglang.git" did not exist on "f5d30dae89fd413cabd2d573c2eed9907d233dcb"
Commit 41992cd2 authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

Move OnDeviceEmbedding to keras_nlp.

PiperOrigin-RevId: 330754739
parent fffea332
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Keras-NLP layers package definition.""" """Keras-NLP layers package definition."""
from official.nlp.keras_nlp.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.keras_nlp.layers.position_embedding import PositionEmbedding from official.nlp.keras_nlp.layers.position_embedding import PositionEmbedding
from official.nlp.keras_nlp.layers.self_attention_mask import SelfAttentionMask from official.nlp.keras_nlp.layers.self_attention_mask import SelfAttentionMask
from official.nlp.keras_nlp.layers.transformer_encoder_block import TransformerEncoderBlock from official.nlp.keras_nlp.layers.transformer_encoder_block import TransformerEncoderBlock
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based one-hot embedding layer."""
# pylint: disable=g-classes-have-attributes
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package="keras_nlp")
class OnDeviceEmbedding(tf.keras.layers.Layer):
"""Performs an embedding lookup suitable for accelerator devices.
This layer uses either tf.gather or tf.one_hot to translate integer indices to
float embeddings.
Arguments:
vocab_size: Number of elements in the vocabulary.
embedding_width: Output size of the embedding layer.
initializer: The initializer to use for the embedding weights. Defaults to
"glorot_uniform".
use_one_hot: Whether to use tf.one_hot over tf.gather for the embedding
lookup. Defaults to False (that is, using tf.gather). Setting this option
to True may improve performance, especially on small vocabulary sizes, but
will generally require more memory.
use_scale: Whether to scale the output embeddings. Defaults to False (that
is, not to scale). Setting this option to True will let values in output
embeddings multiplied by self._embedding_width ** 0.5.
"""
def __init__(self,
vocab_size,
embedding_width,
initializer="glorot_uniform",
use_one_hot=False,
use_scale=False,
**kwargs):
super(OnDeviceEmbedding, self).__init__(**kwargs)
self._vocab_size = vocab_size
self._embedding_width = embedding_width
self._initializer = initializer
self._use_one_hot = use_one_hot
self._use_scale = use_scale
def get_config(self):
config = {
"vocab_size": self._vocab_size,
"embedding_width": self._embedding_width,
"initializer": self._initializer,
"use_one_hot": self._use_one_hot,
"use_scale": self._use_scale,
}
base_config = super(OnDeviceEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
self.embeddings = self.add_weight(
"embeddings",
shape=[self._vocab_size, self._embedding_width],
initializer=self._initializer,
dtype=tf.float32)
super(OnDeviceEmbedding, self).build(input_shape)
def call(self, inputs):
flat_inputs = tf.reshape(inputs, [-1])
if self._use_one_hot:
one_hot_data = tf.one_hot(
flat_inputs, depth=self._vocab_size, dtype=self.embeddings.dtype)
embeddings = tf.matmul(one_hot_data, self.embeddings)
else:
embeddings = tf.gather(self.embeddings, flat_inputs)
embeddings = tf.reshape(
embeddings,
# Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
if self._use_scale:
embeddings *= self._embedding_width**0.5
return embeddings
...@@ -18,7 +18,7 @@ import numpy as np ...@@ -18,7 +18,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import on_device_embedding from official.nlp.keras_nlp.layers import on_device_embedding
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
......
...@@ -15,78 +15,7 @@ ...@@ -15,78 +15,7 @@
"""Keras-based one-hot embedding layer.""" """Keras-based one-hot embedding layer."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
import tensorflow as tf from official.nlp import keras_nlp
@tf.keras.utils.register_keras_serializable(package="Text") OnDeviceEmbedding = keras_nlp.layers.OnDeviceEmbedding
class OnDeviceEmbedding(tf.keras.layers.Layer):
"""Performs an embedding lookup suitable for accelerator devices.
This layer uses either tf.gather or tf.one_hot to translate integer indices to
float embeddings.
Arguments:
vocab_size: Number of elements in the vocabulary.
embedding_width: Output size of the embedding layer.
initializer: The initializer to use for the embedding weights. Defaults to
"glorot_uniform".
use_one_hot: Whether to use tf.one_hot over tf.gather for the embedding
lookup. Defaults to False (that is, using tf.gather). Setting this option
to True may improve performance, especially on small vocabulary sizes, but
will generally require more memory.
use_scale: Whether to scale the output embeddings. Defaults to False (that
is, not to scale). Setting this option to True will let values in output
embeddings multiplied by self._embedding_width ** 0.5.
"""
def __init__(self,
vocab_size,
embedding_width,
initializer="glorot_uniform",
use_one_hot=False,
use_scale=False,
**kwargs):
super(OnDeviceEmbedding, self).__init__(**kwargs)
self._vocab_size = vocab_size
self._embedding_width = embedding_width
self._initializer = initializer
self._use_one_hot = use_one_hot
self._use_scale = use_scale
def get_config(self):
config = {
"vocab_size": self._vocab_size,
"embedding_width": self._embedding_width,
"initializer": self._initializer,
"use_one_hot": self._use_one_hot,
"use_scale": self._use_scale,
}
base_config = super(OnDeviceEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
self.embeddings = self.add_weight(
"embeddings",
shape=[self._vocab_size, self._embedding_width],
initializer=self._initializer,
dtype=tf.float32)
super(OnDeviceEmbedding, self).build(input_shape)
def call(self, inputs):
flat_inputs = tf.reshape(inputs, [-1])
if self._use_one_hot:
one_hot_data = tf.one_hot(
flat_inputs, depth=self._vocab_size, dtype=self.embeddings.dtype)
embeddings = tf.matmul(one_hot_data, self.embeddings)
else:
embeddings = tf.gather(self.embeddings, flat_inputs)
embeddings = tf.reshape(
embeddings,
# Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
if self._use_scale:
embeddings *= self._embedding_width**0.5
return embeddings
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment