Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4a34c084
Commit
4a34c084
authored
Aug 25, 2022
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Aug 25, 2022
Browse files
Internal change
PiperOrigin-RevId: 470043432
parent
44feee08
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
200 additions
and
1 deletion
+200
-1
official/nlp/configs/encoders.py
official/nlp/configs/encoders.py
+49
-0
official/nlp/modeling/layers/__init__.py
official/nlp/modeling/layers/__init__.py
+1
-0
official/nlp/modeling/layers/factorized_embedding.py
official/nlp/modeling/layers/factorized_embedding.py
+76
-0
official/nlp/modeling/layers/factorized_embedding_test.py
official/nlp/modeling/layers/factorized_embedding_test.py
+70
-0
official/nlp/tools/export_tfhub_lib.py
official/nlp/tools/export_tfhub_lib.py
+4
-1
No files found.
official/nlp/configs/encoders.py
View file @
4a34c084
...
@@ -221,6 +221,27 @@ class XLNetEncoderConfig(hyperparams.Config):
...
@@ -221,6 +221,27 @@ class XLNetEncoderConfig(hyperparams.Config):
two_stream
:
bool
=
False
two_stream
:
bool
=
False
@
dataclasses
.
dataclass
class
QueryBertConfig
(
hyperparams
.
Config
):
"""Query BERT encoder configuration."""
vocab_size
:
int
=
30522
hidden_size
:
int
=
768
num_layers
:
int
=
12
num_attention_heads
:
int
=
12
hidden_activation
:
str
=
"gelu"
intermediate_size
:
int
=
3072
dropout_rate
:
float
=
0.1
attention_dropout_rate
:
float
=
0.1
max_position_embeddings
:
int
=
512
type_vocab_size
:
int
=
2
initializer_range
:
float
=
0.02
embedding_size
:
Optional
[
int
]
=
None
output_range
:
Optional
[
int
]
=
None
return_all_encoder_outputs
:
bool
=
False
# Pre/Post-LN Transformer
norm_first
:
bool
=
False
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
EncoderConfig
(
hyperparams
.
OneOfConfig
):
class
EncoderConfig
(
hyperparams
.
OneOfConfig
):
"""Encoder configuration."""
"""Encoder configuration."""
...
@@ -233,6 +254,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
...
@@ -233,6 +254,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
mobilebert
:
MobileBertEncoderConfig
=
MobileBertEncoderConfig
()
mobilebert
:
MobileBertEncoderConfig
=
MobileBertEncoderConfig
()
reuse
:
ReuseEncoderConfig
=
ReuseEncoderConfig
()
reuse
:
ReuseEncoderConfig
=
ReuseEncoderConfig
()
xlnet
:
XLNetEncoderConfig
=
XLNetEncoderConfig
()
xlnet
:
XLNetEncoderConfig
=
XLNetEncoderConfig
()
query_bert
:
QueryBertConfig
=
QueryBertConfig
()
# If `any` is used, the encoder building relies on any.BUILDER.
# If `any` is used, the encoder building relies on any.BUILDER.
any
:
hyperparams
.
Config
=
hyperparams
.
Config
()
any
:
hyperparams
.
Config
=
hyperparams
.
Config
()
...
@@ -513,6 +535,33 @@ def build_encoder(config: EncoderConfig,
...
@@ -513,6 +535,33 @@ def build_encoder(config: EncoderConfig,
recursive
=
True
)
recursive
=
True
)
return
networks
.
EncoderScaffold
(
**
kwargs
)
return
networks
.
EncoderScaffold
(
**
kwargs
)
if
encoder_type
==
"query_bert"
:
embedding_layer
=
layers
.
FactorizedEmbedding
(
vocab_size
=
encoder_cfg
.
vocab_size
,
embedding_width
=
encoder_cfg
.
embedding_size
,
output_dim
=
encoder_cfg
.
hidden_size
,
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
encoder_cfg
.
initializer_range
),
name
=
"word_embeddings"
)
return
networks
.
BertEncoderV2
(
vocab_size
=
encoder_cfg
.
vocab_size
,
hidden_size
=
encoder_cfg
.
hidden_size
,
num_layers
=
encoder_cfg
.
num_layers
,
num_attention_heads
=
encoder_cfg
.
num_attention_heads
,
intermediate_size
=
encoder_cfg
.
intermediate_size
,
activation
=
tf_utils
.
get_activation
(
encoder_cfg
.
hidden_activation
),
dropout_rate
=
encoder_cfg
.
dropout_rate
,
attention_dropout_rate
=
encoder_cfg
.
attention_dropout_rate
,
max_sequence_length
=
encoder_cfg
.
max_position_embeddings
,
type_vocab_size
=
encoder_cfg
.
type_vocab_size
,
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
encoder_cfg
.
initializer_range
),
output_range
=
encoder_cfg
.
output_range
,
embedding_layer
=
embedding_layer
,
return_all_encoder_outputs
=
encoder_cfg
.
return_all_encoder_outputs
,
dict_outputs
=
True
,
norm_first
=
encoder_cfg
.
norm_first
)
bert_encoder_cls
=
networks
.
BertEncoder
bert_encoder_cls
=
networks
.
BertEncoder
if
encoder_type
==
"bert_v2"
:
if
encoder_type
==
"bert_v2"
:
bert_encoder_cls
=
networks
.
BertEncoderV2
bert_encoder_cls
=
networks
.
BertEncoderV2
...
...
official/nlp/modeling/layers/__init__.py
View file @
4a34c084
...
@@ -22,6 +22,7 @@ from official.nlp.modeling.layers.bigbird_attention import BigBirdAttention
...
@@ -22,6 +22,7 @@ from official.nlp.modeling.layers.bigbird_attention import BigBirdAttention
from
official.nlp.modeling.layers.bigbird_attention
import
BigBirdMasks
from
official.nlp.modeling.layers.bigbird_attention
import
BigBirdMasks
from
official.nlp.modeling.layers.block_diag_feedforward
import
BlockDiagFeedforward
from
official.nlp.modeling.layers.block_diag_feedforward
import
BlockDiagFeedforward
from
official.nlp.modeling.layers.cls_head
import
*
from
official.nlp.modeling.layers.cls_head
import
*
from
official.nlp.modeling.layers.factorized_embedding
import
FactorizedEmbedding
from
official.nlp.modeling.layers.gated_feedforward
import
GatedFeedforward
from
official.nlp.modeling.layers.gated_feedforward
import
GatedFeedforward
from
official.nlp.modeling.layers.gaussian_process
import
RandomFeatureGaussianProcess
from
official.nlp.modeling.layers.gaussian_process
import
RandomFeatureGaussianProcess
from
official.nlp.modeling.layers.kernel_attention
import
KernelAttention
from
official.nlp.modeling.layers.kernel_attention
import
KernelAttention
...
...
official/nlp/modeling/layers/factorized_embedding.py
0 → 100644
View file @
4a34c084
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A factorized embedding layer."""
# pylint: disable=g-classes-have-attributes
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.nlp.modeling.layers
import
on_device_embedding
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
FactorizedEmbedding
(
on_device_embedding
.
OnDeviceEmbedding
):
"""A factorized embeddings layer for supporting larger embeddings.
Arguments:
vocab_size: Number of elements in the vocabulary.
embedding_width: Width of word embeddings.
output_dim: The output dimension of this layer.
initializer: The initializer to use for the embedding weights. Defaults to
"glorot_uniform".
use_one_hot: Whether to use tf.one_hot over tf.gather for the embedding
lookup. Defaults to False (that is, using tf.gather). Setting this option
to True may improve performance, especially on small vocabulary sizes, but
will generally require more memory.
scale_factor: Whether to scale the output embeddings. Defaults to None (that
is, not to scale). Setting this option to a float will let values in
output embeddings multiplied by scale_factor.
"""
def
__init__
(
self
,
vocab_size
:
int
,
embedding_width
:
int
,
output_dim
:
int
,
initializer
=
'glorot_uniform'
,
use_one_hot
=
False
,
scale_factor
=
None
,
**
kwargs
):
super
().
__init__
(
vocab_size
=
vocab_size
,
embedding_width
=
embedding_width
,
initializer
=
initializer
,
use_one_hot
=
use_one_hot
,
scale_factor
=
scale_factor
,
**
kwargs
)
self
.
_output_dim
=
output_dim
def
get_config
(
self
):
config
=
{
'output_dim'
:
self
.
_output_dim
}
base_config
=
super
().
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
build
(
self
,
input_shape
):
self
.
_embedding_projection
=
tf
.
keras
.
layers
.
EinsumDense
(
'...x,xy->...y'
,
output_shape
=
self
.
_output_dim
,
bias_axes
=
None
,
kernel_initializer
=
tf_utils
.
clone_initializer
(
self
.
_initializer
),
name
=
'embedding_projection'
)
super
().
build
(
input_shape
)
def
call
(
self
,
inputs
):
output
=
super
().
call
(
inputs
)
return
self
.
_embedding_projection
(
output
)
official/nlp/modeling/layers/factorized_embedding_test.py
0 → 100644
View file @
4a34c084
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for FactorizedEmbedding layer."""
import
numpy
as
np
import
tensorflow
as
tf
from
official.nlp.modeling.layers
import
factorized_embedding
class
FactorizedEmbeddingTest
(
tf
.
test
.
TestCase
):
def
test_layer_creation
(
self
):
vocab_size
=
31
embedding_width
=
27
output_dim
=
45
test_layer
=
factorized_embedding
.
FactorizedEmbedding
(
vocab_size
=
vocab_size
,
embedding_width
=
embedding_width
,
output_dim
=
output_dim
)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length
=
23
input_tensor
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
),
dtype
=
tf
.
int32
)
output_tensor
=
test_layer
(
input_tensor
)
# The output should be the same as the input, save that it has an extra
# embedding_width dimension on the end.
expected_output_shape
=
[
None
,
sequence_length
,
output_dim
]
self
.
assertEqual
(
expected_output_shape
,
output_tensor
.
shape
.
as_list
())
self
.
assertEqual
(
output_tensor
.
dtype
,
tf
.
float32
)
def
test_layer_invocation
(
self
):
vocab_size
=
31
embedding_width
=
27
output_dim
=
45
test_layer
=
factorized_embedding
.
FactorizedEmbedding
(
vocab_size
=
vocab_size
,
embedding_width
=
embedding_width
,
output_dim
=
output_dim
)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length
=
23
input_tensor
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
),
dtype
=
tf
.
int32
)
output_tensor
=
test_layer
(
input_tensor
)
# Create a model from the test layer.
model
=
tf
.
keras
.
Model
(
input_tensor
,
output_tensor
)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size
=
3
input_data
=
np
.
random
.
randint
(
vocab_size
,
size
=
(
batch_size
,
sequence_length
))
output
=
model
.
predict
(
input_data
)
self
.
assertEqual
(
tf
.
float32
,
output
.
dtype
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/nlp/tools/export_tfhub_lib.py
View file @
4a34c084
...
@@ -133,7 +133,10 @@ def _create_model(
...
@@ -133,7 +133,10 @@ def _create_model(
encoder_network
=
encoder
,
encoder_network
=
encoder
,
mlm_activation
=
tf_utils
.
get_activation
(
hidden_act
))
mlm_activation
=
tf_utils
.
get_activation
(
hidden_act
))
pretrainer_inputs_dict
=
{
x
.
name
:
x
for
x
in
pretrainer
.
inputs
}
if
isinstance
(
pretrainer
.
inputs
,
dict
):
pretrainer_inputs_dict
=
pretrainer
.
inputs
else
:
pretrainer_inputs_dict
=
{
x
.
name
:
x
for
x
in
pretrainer
.
inputs
}
pretrainer_output_dict
=
pretrainer
(
pretrainer_inputs_dict
)
pretrainer_output_dict
=
pretrainer
(
pretrainer_inputs_dict
)
mlm_model
=
tf
.
keras
.
Model
(
mlm_model
=
tf
.
keras
.
Model
(
inputs
=
pretrainer_inputs_dict
,
outputs
=
pretrainer_output_dict
)
inputs
=
pretrainer_inputs_dict
,
outputs
=
pretrainer_output_dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment