Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bb124157
Commit
bb124157
authored
Mar 10, 2021
by
stephenwu
Browse files
Merge branch 'master' of
https://github.com/tensorflow/models
into RTESuperGLUE
parents
2e9bb539
0edeb7f6
Changes
386
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
58 additions
and
50 deletions
+58
-50
official/nlp/modeling/models/dual_encoder_test.py
official/nlp/modeling/models/dual_encoder_test.py
+2
-2
official/nlp/modeling/models/electra_pretrainer.py
official/nlp/modeling/models/electra_pretrainer.py
+2
-2
official/nlp/modeling/models/electra_pretrainer_test.py
official/nlp/modeling/models/electra_pretrainer_test.py
+2
-2
official/nlp/modeling/models/seq2seq_transformer.py
official/nlp/modeling/models/seq2seq_transformer.py
+2
-2
official/nlp/modeling/models/seq2seq_transformer_test.py
official/nlp/modeling/models/seq2seq_transformer_test.py
+2
-2
official/nlp/modeling/models/xlnet.py
official/nlp/modeling/models/xlnet.py
+2
-2
official/nlp/modeling/models/xlnet_test.py
official/nlp/modeling/models/xlnet_test.py
+2
-2
official/nlp/modeling/networks/__init__.py
official/nlp/modeling/networks/__init__.py
+2
-2
official/nlp/modeling/networks/albert_encoder.py
official/nlp/modeling/networks/albert_encoder.py
+2
-2
official/nlp/modeling/networks/albert_encoder_test.py
official/nlp/modeling/networks/albert_encoder_test.py
+2
-2
official/nlp/modeling/networks/bert_encoder.py
official/nlp/modeling/networks/bert_encoder.py
+3
-3
official/nlp/modeling/networks/bert_encoder_test.py
official/nlp/modeling/networks/bert_encoder_test.py
+2
-2
official/nlp/modeling/networks/classification.py
official/nlp/modeling/networks/classification.py
+2
-2
official/nlp/modeling/networks/classification_test.py
official/nlp/modeling/networks/classification_test.py
+2
-2
official/nlp/modeling/networks/encoder_scaffold.py
official/nlp/modeling/networks/encoder_scaffold.py
+2
-3
official/nlp/modeling/networks/encoder_scaffold_test.py
official/nlp/modeling/networks/encoder_scaffold_test.py
+2
-2
official/nlp/modeling/networks/mobile_bert_encoder.py
official/nlp/modeling/networks/mobile_bert_encoder.py
+13
-4
official/nlp/modeling/networks/mobile_bert_encoder_test.py
official/nlp/modeling/networks/mobile_bert_encoder_test.py
+8
-6
official/nlp/modeling/networks/packed_sequence_embedding.py
official/nlp/modeling/networks/packed_sequence_embedding.py
+2
-3
official/nlp/modeling/networks/packed_sequence_embedding_test.py
...l/nlp/modeling/networks/packed_sequence_embedding_test.py
+2
-3
No files found.
official/nlp/modeling/models/dual_encoder_test.py
View file @
bb124157
# Copyright 202
0
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Tests for dual encoder network."""
"""Tests for dual encoder network."""
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
...
...
official/nlp/modeling/models/electra_pretrainer.py
View file @
bb124157
# Copyright 202
0
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Trainer network for ELECTRA models."""
"""Trainer network for ELECTRA models."""
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-classes-have-attributes
...
...
official/nlp/modeling/models/electra_pretrainer_test.py
View file @
bb124157
# Copyright 202
0
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA pre trainer network."""
"""Tests for ELECTRA pre trainer network."""
import
tensorflow
as
tf
import
tensorflow
as
tf
...
...
official/nlp/modeling/models/seq2seq_transformer.py
View file @
bb124157
# Copyright 202
0
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Implement Seq2Seq Transformer model by TF official NLP library.
"""Implement Seq2Seq Transformer model by TF official NLP library.
Model paper: https://arxiv.org/pdf/1706.03762.pdf
Model paper: https://arxiv.org/pdf/1706.03762.pdf
...
...
official/nlp/modeling/models/seq2seq_transformer_test.py
View file @
bb124157
# Copyright 202
0
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Test Transformer model."""
"""Test Transformer model."""
from
absl
import
logging
from
absl
import
logging
...
...
official/nlp/modeling/models/xlnet.py
View file @
bb124157
# Copyright 202
0
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""XLNet models."""
"""XLNet models."""
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-classes-have-attributes
...
...
official/nlp/modeling/models/xlnet_test.py
View file @
bb124157
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Tests for XLNet classifier network."""
"""Tests for XLNet classifier network."""
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
...
...
official/nlp/modeling/networks/__init__.py
View file @
bb124157
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Networks package definition."""
"""Networks package definition."""
from
official.nlp.modeling.networks.albert_encoder
import
AlbertEncoder
from
official.nlp.modeling.networks.albert_encoder
import
AlbertEncoder
from
official.nlp.modeling.networks.bert_encoder
import
BertEncoder
from
official.nlp.modeling.networks.bert_encoder
import
BertEncoder
...
...
official/nlp/modeling/networks/albert_encoder.py
View file @
bb124157
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""ALBERT (https://arxiv.org/abs/1810.04805) text encoder network."""
"""ALBERT (https://arxiv.org/abs/1810.04805) text encoder network."""
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-classes-have-attributes
import
collections
import
collections
...
...
official/nlp/modeling/networks/albert_encoder_test.py
View file @
bb124157
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Tests for ALBERT transformer-based text encoder network."""
"""Tests for ALBERT transformer-based text encoder network."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
...
...
official/nlp/modeling/networks/bert_encoder.py
View file @
bb124157
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-classes-have-attributes
import
collections
import
collections
...
@@ -65,7 +65,7 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
...
@@ -65,7 +65,7 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
keyed by `encoder_outputs`.
keyed by `encoder_outputs`.
output_range: The sequence output range, [0, output_range), by slicing the
output_range: The sequence output range, [0, output_range), by slicing the
target sequence of the last transformer layer. `None` means the entire
target sequence of the last transformer layer. `None` means the entire
target sequence will attend to the source sequence, which y
e
ilds the full
target sequence will attend to the source sequence, which yi
e
lds the full
output.
output.
embedding_width: The width of the word embeddings. If the embedding width is
embedding_width: The width of the word embeddings. If the embedding width is
not equal to hidden size, embedding parameters will be factorized into two
not equal to hidden size, embedding parameters will be factorized into two
...
...
official/nlp/modeling/networks/bert_encoder_test.py
View file @
bb124157
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Tests for transformer-based bert encoder network."""
"""Tests for transformer-based bert encoder network."""
# Import libraries
# Import libraries
...
...
official/nlp/modeling/networks/classification.py
View file @
bb124157
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Classification and regression network."""
"""Classification and regression network."""
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-classes-have-attributes
import
collections
import
collections
...
...
official/nlp/modeling/networks/classification_test.py
View file @
bb124157
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Tests for classification network."""
"""Tests for classification network."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
...
...
official/nlp/modeling/networks/encoder_scaffold.py
View file @
bb124157
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,7 +11,7 @@
...
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""
"""Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-classes-have-attributes
import
inspect
import
inspect
...
...
official/nlp/modeling/networks/encoder_scaffold_test.py
View file @
bb124157
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Tests for EncoderScaffold network."""
"""Tests for EncoderScaffold network."""
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
...
...
official/nlp/modeling/networks/mobile_bert_encoder.py
View file @
bb124157
# Copyright 202
0
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""MobileBERT text encoder network."""
"""MobileBERT text encoder network."""
import
gin
import
gin
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -43,6 +43,7 @@ class MobileBERTEncoder(tf.keras.Model):
...
@@ -43,6 +43,7 @@ class MobileBERTEncoder(tf.keras.Model):
num_feedforward_networks
=
4
,
num_feedforward_networks
=
4
,
normalization_type
=
'no_norm'
,
normalization_type
=
'no_norm'
,
classifier_activation
=
False
,
classifier_activation
=
False
,
input_mask_dtype
=
'int32'
,
**
kwargs
):
**
kwargs
):
"""Class initialization.
"""Class initialization.
...
@@ -76,6 +77,11 @@ class MobileBERTEncoder(tf.keras.Model):
...
@@ -76,6 +77,11 @@ class MobileBERTEncoder(tf.keras.Model):
MobileBERT paper. 'layer_norm' is used for the teacher model.
MobileBERT paper. 'layer_norm' is used for the teacher model.
classifier_activation: If using the tanh activation for the final
classifier_activation: If using the tanh activation for the final
representation of the [CLS] token in fine-tuning.
representation of the [CLS] token in fine-tuning.
input_mask_dtype: The dtype of `input_mask` tensor, which is one of the
input tensors of this encoder. Defaults to `int32`. If you want
to use `tf.lite` quantization, which does not support `Cast` op,
please set this argument to `tf.float32` and feed `input_mask`
tensor with values in float32 to avoid `tf.cast` in the computation.
**kwargs: Other keyworded and arguments.
**kwargs: Other keyworded and arguments.
"""
"""
self
.
_self_setattr_tracking
=
False
self
.
_self_setattr_tracking
=
False
...
@@ -115,11 +121,14 @@ class MobileBERTEncoder(tf.keras.Model):
...
@@ -115,11 +121,14 @@ class MobileBERTEncoder(tf.keras.Model):
input_ids
=
tf
.
keras
.
layers
.
Input
(
input_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
,
name
=
'input_word_ids'
)
shape
=
(
None
,),
dtype
=
tf
.
int32
,
name
=
'input_word_ids'
)
input_mask
=
tf
.
keras
.
layers
.
Input
(
input_mask
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
,
name
=
'input_mask'
)
shape
=
(
None
,),
dtype
=
input_mask_dtype
,
name
=
'input_mask'
)
type_ids
=
tf
.
keras
.
layers
.
Input
(
type_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
None
,),
dtype
=
tf
.
int32
,
name
=
'input_type_ids'
)
shape
=
(
None
,),
dtype
=
tf
.
int32
,
name
=
'input_type_ids'
)
self
.
inputs
=
[
input_ids
,
input_mask
,
type_ids
]
self
.
inputs
=
[
input_ids
,
input_mask
,
type_ids
]
attention_mask
=
keras_nlp
.
layers
.
SelfAttentionMask
()(
input_ids
,
input_mask
)
# The dtype of `attention_mask` will the same as the dtype of `input_mask`.
attention_mask
=
keras_nlp
.
layers
.
SelfAttentionMask
()(
input_mask
,
input_mask
)
# build the computation graph
# build the computation graph
all_layer_outputs
=
[]
all_layer_outputs
=
[]
...
...
official/nlp/modeling/networks/mobile_bert_encoder_test.py
View file @
bb124157
# Copyright 202
0
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
...
@@ -21,7 +21,7 @@ from official.nlp.modeling.networks import mobile_bert_encoder
...
@@ -21,7 +21,7 @@ from official.nlp.modeling.networks import mobile_bert_encoder
def
generate_fake_input
(
batch_size
=
1
,
seq_len
=
5
,
vocab_size
=
10000
,
seed
=
0
):
def
generate_fake_input
(
batch_size
=
1
,
seq_len
=
5
,
vocab_size
=
10000
,
seed
=
0
):
"""Generate consis
ita
nt fake integer input sequences."""
"""Generate consis
te
nt fake integer input sequences."""
np
.
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
fake_input
=
[]
fake_input
=
[]
for
_
in
range
(
batch_size
):
for
_
in
range
(
batch_size
):
...
@@ -89,7 +89,8 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -89,7 +89,8 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertIsInstance
(
all_layer_output
,
list
)
self
.
assertIsInstance
(
all_layer_output
,
list
)
self
.
assertLen
(
all_layer_output
,
num_blocks
+
1
)
self
.
assertLen
(
all_layer_output
,
num_blocks
+
1
)
def
test_mobilebert_encoder_invocation
(
self
):
@
parameterized
.
parameters
(
'int32'
,
'float32'
)
def
test_mobilebert_encoder_invocation
(
self
,
input_mask_dtype
):
vocab_size
=
100
vocab_size
=
100
hidden_size
=
32
hidden_size
=
32
sequence_length
=
16
sequence_length
=
16
...
@@ -97,10 +98,11 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -97,10 +98,11 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
test_network
=
mobile_bert_encoder
.
MobileBERTEncoder
(
test_network
=
mobile_bert_encoder
.
MobileBERTEncoder
(
word_vocab_size
=
vocab_size
,
word_vocab_size
=
vocab_size
,
hidden_size
=
hidden_size
,
hidden_size
=
hidden_size
,
num_blocks
=
num_blocks
)
num_blocks
=
num_blocks
,
input_mask_dtype
=
input_mask_dtype
)
word_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
word_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
mask
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
mask
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
input_mask_dtype
)
type_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
type_ids
=
tf
.
keras
.
Input
(
shape
=
(
sequence_length
,),
dtype
=
tf
.
int32
)
outputs
=
test_network
([
word_ids
,
mask
,
type_ids
])
outputs
=
test_network
([
word_ids
,
mask
,
type_ids
])
model
=
tf
.
keras
.
Model
([
word_ids
,
mask
,
type_ids
],
outputs
)
model
=
tf
.
keras
.
Model
([
word_ids
,
mask
,
type_ids
],
outputs
)
...
...
official/nlp/modeling/networks/packed_sequence_embedding.py
View file @
bb124157
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,7 +11,7 @@
...
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""An embedding network supporting packed sequences and position ids."""
"""An embedding network supporting packed sequences and position ids."""
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-classes-have-attributes
import
collections
import
collections
...
...
official/nlp/modeling/networks/packed_sequence_embedding_test.py
View file @
bb124157
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,7 +11,7 @@
...
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.modeling.networks.packed_sequence_embedding."""
"""Tests for official.nlp.modeling.networks.packed_sequence_embedding."""
# Import libraries
# Import libraries
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment