Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
41992cd2
Commit
41992cd2
authored
Sep 09, 2020
by
Zhenyu Tan
Committed by
A. Unique TensorFlower
Sep 09, 2020
Browse files
Move OnDeviceEmbedding to keras_nlp.
PiperOrigin-RevId: 330754739
parent
fffea332
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
96 additions
and
74 deletions
+96
-74
official/nlp/keras_nlp/layers/__init__.py
official/nlp/keras_nlp/layers/__init__.py
+1
-0
official/nlp/keras_nlp/layers/on_device_embedding.py
official/nlp/keras_nlp/layers/on_device_embedding.py
+92
-0
official/nlp/keras_nlp/layers/on_device_embedding_test.py
official/nlp/keras_nlp/layers/on_device_embedding_test.py
+1
-1
official/nlp/modeling/layers/on_device_embedding.py
official/nlp/modeling/layers/on_device_embedding.py
+2
-73
No files found.
official/nlp/keras_nlp/layers/__init__.py
View file @
41992cd2
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Keras-NLP layers package definition."""
from
official.nlp.keras_nlp.layers.on_device_embedding
import
OnDeviceEmbedding
from
official.nlp.keras_nlp.layers.position_embedding
import
PositionEmbedding
from
official.nlp.keras_nlp.layers.self_attention_mask
import
SelfAttentionMask
from
official.nlp.keras_nlp.layers.transformer_encoder_block
import
TransformerEncoderBlock
official/nlp/keras_nlp/layers/on_device_embedding.py
0 → 100644
View file @
41992cd2
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based one-hot embedding layer."""
# pylint: disable=g-classes-have-attributes
import
tensorflow
as
tf
@tf.keras.utils.register_keras_serializable(package="keras_nlp")
class OnDeviceEmbedding(tf.keras.layers.Layer):
  """Embedding lookup layer that turns integer ids into float vectors.

  The lookup is performed either with `tf.gather` on the embedding table or by
  multiplying a `tf.one_hot` encoding of the ids with that table.

  Arguments:
    vocab_size: Number of elements in the vocabulary.
    embedding_width: Output size of the embedding layer.
    initializer: The initializer to use for the embedding weights. Defaults to
      "glorot_uniform".
    use_one_hot: Whether to use tf.one_hot over tf.gather for the embedding
      lookup. Defaults to False (that is, using tf.gather). Setting this option
      to True may improve performance, especially on small vocabulary sizes,
      but will generally require more memory.
    use_scale: Whether to scale the output embeddings. Defaults to False (that
      is, not to scale). When True, the output embeddings are multiplied by
      embedding_width ** 0.5.
  """

  def __init__(self,
               vocab_size,
               embedding_width,
               initializer="glorot_uniform",
               use_one_hot=False,
               use_scale=False,
               **kwargs):
    """Stores the layer hyperparameters; the weights are created in `build`."""
    super().__init__(**kwargs)

    self._vocab_size = vocab_size
    self._embedding_width = embedding_width
    self._initializer = initializer
    self._use_one_hot = use_one_hot
    self._use_scale = use_scale

  def get_config(self):
    """Returns the base layer config extended with this layer's arguments."""
    own_config = {
        "vocab_size": self._vocab_size,
        "embedding_width": self._embedding_width,
        "initializer": self._initializer,
        "use_one_hot": self._use_one_hot,
        "use_scale": self._use_scale,
    }
    # Entries of this layer take precedence over same-named base entries.
    merged = dict(super().get_config())
    merged.update(own_config)
    return merged

  def build(self, input_shape):
    """Creates the [vocab_size, embedding_width] float32 embedding table."""
    table_shape = [self._vocab_size, self._embedding_width]
    self.embeddings = self.add_weight(
        "embeddings",
        shape=table_shape,
        initializer=self._initializer,
        dtype=tf.float32)

    super().build(input_shape)

  def call(self, inputs):
    """Looks up an embedding vector for every id in `inputs`.

    Args:
      inputs: Tensor of integer ids; it is flattened for the lookup.

    Returns:
      A tensor whose shape is the shape of `inputs` with `embedding_width`
      appended as the last dimension.
    """
    flat_ids = tf.reshape(inputs, [-1])

    if not self._use_one_hot:
      looked_up = tf.gather(self.embeddings, flat_ids)
    else:
      one_hot_ids = tf.one_hot(
          flat_ids, depth=self._vocab_size, dtype=self.embeddings.dtype)
      looked_up = tf.matmul(one_hot_ids, self.embeddings)

    # Work around b/142213824: prefer concat to shape over a Python list.
    output_shape = tf.concat(
        [tf.shape(inputs), [self._embedding_width]], axis=0)
    outputs = tf.reshape(looked_up, output_shape)
    # Re-attach the statically known dimensions lost by the dynamic reshape.
    outputs.set_shape(inputs.shape.as_list() + [self._embedding_width])

    if self._use_scale:
      outputs = outputs * self._embedding_width ** 0.5
    return outputs
official/nlp/
modeling
/layers/on_device_embedding_test.py
→
official/nlp/
keras_nlp
/layers/on_device_embedding_test.py
View file @
41992cd2
...
...
@@ -18,7 +18,7 @@ import numpy as np
import
tensorflow
as
tf
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.nlp.
modeling
.layers
import
on_device_embedding
from
official.nlp.
keras_nlp
.layers
import
on_device_embedding
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
...
...
official/nlp/modeling/layers/on_device_embedding.py
View file @
41992cd2
...
...
@@ -15,78 +15,7 @@
"""Keras-based one-hot embedding layer."""
# pylint: disable=g-classes-have-attributes
import
tensorflow
as
tf
from
official.nlp
import
keras_nlp
@tf.keras.utils.register_keras_serializable(package="Text")
class OnDeviceEmbedding(tf.keras.layers.Layer):
  """Embedding lookup layer that turns integer ids into float vectors.

  The lookup is performed either with `tf.gather` on the embedding table or by
  multiplying a `tf.one_hot` encoding of the ids with that table.

  Arguments:
    vocab_size: Number of elements in the vocabulary.
    embedding_width: Output size of the embedding layer.
    initializer: The initializer to use for the embedding weights. Defaults to
      "glorot_uniform".
    use_one_hot: Whether to use tf.one_hot over tf.gather for the embedding
      lookup. Defaults to False (that is, using tf.gather). Setting this option
      to True may improve performance, especially on small vocabulary sizes,
      but will generally require more memory.
    use_scale: Whether to scale the output embeddings. Defaults to False (that
      is, not to scale). When True, the output embeddings are multiplied by
      embedding_width ** 0.5.
  """

  def __init__(self,
               vocab_size,
               embedding_width,
               initializer="glorot_uniform",
               use_one_hot=False,
               use_scale=False,
               **kwargs):
    """Stores the layer hyperparameters; the weights are created in `build`."""
    super(OnDeviceEmbedding, self).__init__(**kwargs)

    self._vocab_size = vocab_size
    self._embedding_width = embedding_width
    self._initializer = initializer
    self._use_one_hot = use_one_hot
    self._use_scale = use_scale

  def get_config(self):
    """Returns the base layer config extended with this layer's arguments."""
    own_config = {
        "vocab_size": self._vocab_size,
        "embedding_width": self._embedding_width,
        "initializer": self._initializer,
        "use_one_hot": self._use_one_hot,
        "use_scale": self._use_scale,
    }
    # Entries of this layer take precedence over same-named base entries.
    merged = dict(super(OnDeviceEmbedding, self).get_config())
    merged.update(own_config)
    return merged

  def build(self, input_shape):
    """Creates the [vocab_size, embedding_width] float32 embedding table."""
    table_shape = [self._vocab_size, self._embedding_width]
    self.embeddings = self.add_weight(
        "embeddings",
        shape=table_shape,
        initializer=self._initializer,
        dtype=tf.float32)

    super(OnDeviceEmbedding, self).build(input_shape)

  def call(self, inputs):
    """Looks up an embedding vector for every id in `inputs`.

    Args:
      inputs: Tensor of integer ids; it is flattened for the lookup.

    Returns:
      A tensor whose shape is the shape of `inputs` with `embedding_width`
      appended as the last dimension.
    """
    flat_ids = tf.reshape(inputs, [-1])

    if not self._use_one_hot:
      looked_up = tf.gather(self.embeddings, flat_ids)
    else:
      one_hot_ids = tf.one_hot(
          flat_ids, depth=self._vocab_size, dtype=self.embeddings.dtype)
      looked_up = tf.matmul(one_hot_ids, self.embeddings)

    # Work around b/142213824: prefer concat to shape over a Python list.
    output_shape = tf.concat(
        [tf.shape(inputs), [self._embedding_width]], axis=0)
    outputs = tf.reshape(looked_up, output_shape)
    # Re-attach the statically known dimensions lost by the dynamic reshape.
    outputs.set_shape(inputs.shape.as_list() + [self._embedding_width])

    if self._use_scale:
      outputs = outputs * self._embedding_width ** 0.5
    return outputs
# Expose the keras_nlp implementation under this module's original name, so
# `on_device_embedding.OnDeviceEmbedding` here resolves to the keras_nlp layer.
OnDeviceEmbedding = keras_nlp.layers.OnDeviceEmbedding
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment