ModelZoo / ResNet50_tensorflow · Commits

Unverified commit f16a7b5b, authored May 04, 2021 by vedanshu, committed by GitHub on May 04, 2021.

    Merge pull request #1 from tensorflow/master

    new pull

Parents: 8e9296ff 8f58f396

Changes: 298
Showing 18 changed files with 1885 additions and 300 deletions (+1885 −300).
official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy_test.py   +4    -8
official/nlp/modeling/models/README.md                                          +8    -5
official/nlp/modeling/models/__init__.py                                        +13   -4
official/nlp/modeling/models/bert_classifier.py                                 +70   -35
official/nlp/modeling/models/bert_classifier_test.py                            +25   -18
official/nlp/modeling/models/bert_pretrainer.py                                 +101  -51
official/nlp/modeling/models/bert_pretrainer_test.py                            +99   -22
official/nlp/modeling/models/bert_span_labeler.py                               +46   -24
official/nlp/modeling/models/bert_span_labeler_test.py                          +10   -15
official/nlp/modeling/models/bert_token_classifier.py                           +71   -35
official/nlp/modeling/models/bert_token_classifier_test.py                      +27   -17
official/nlp/modeling/models/dual_encoder.py                                    +162  -0
official/nlp/modeling/models/dual_encoder_test.py                               +125  -0
official/nlp/modeling/models/electra_pretrainer.py                              +47   -44
official/nlp/modeling/models/electra_pretrainer_test.py                         +20   -22
official/nlp/modeling/models/seq2seq_transformer.py                             +589  -0
official/nlp/modeling/models/seq2seq_transformer_test.py                        +126  -0
official/nlp/modeling/models/xlnet.py                                           +342  -0

Too many changes to show. To preserve performance only 298 of 298+ files are displayed.
official/nlp/modeling/losses/weighted_sparse_categorical_crossentropy_test.py (View file @ f16a7b5b)

-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,13 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for masked LM loss."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+"""Tests for masked LM loss."""
 import numpy as np
 import tensorflow as tf
 ...
@@ -39,7 +34,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
                          output="predictions"):
     # First, create a transformer stack that we can use to get the LM's
     # vocabulary weight.
-    xformer_stack = networks.TransformerEncoder(
+    xformer_stack = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=1,
         sequence_length=sequence_length,
 ...
@@ -204,5 +199,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
     expected_loss_data = 6.4222
     self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)

 if __name__ == "__main__":
   tf.test.main()
official/nlp/modeling/models/README.md (View file @ f16a7b5b)

 # Models
-Models are combinations of layers and networks that would be trained.
+Models are combinations of `tf.keras` layers and models that can be trained.

 Several pre-built canned models are provided to train encoder networks. These
 models are intended as both convenience functions and canonical examples.

 * [`BertClassifier`](bert_classifier.py) implements a simple classification
 model containing a single classification head using the Classification network.
 It can be used as a regression model as well.

 * [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token
-classification model containing a single classification head using the
-TokenClassification network.
+classification model containing a single classification head over the sequence
+output embeddings.

 * [`BertSpanLabeler`](bert_span_labeler.py) implements a simple single-span
 start-end predictor (that is, a model that predicts two values: a start token
 ...
@@ -20,3 +20,6 @@ index and an end token index), suitable for SQuAD-style tasks.

 * [`BertPretrainer`](bert_pretrainer.py) implements a masked LM and a
 classification head using the Masked LM and Classification networks,
 respectively.
+
+* [`DualEncoder`](dual_encoder.py) implements a dual encoder model, suitable for
+retrieval tasks.
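As a quick orientation (this snippet is not part of the commit; the encoder type and sizes are illustrative assumptions), the canned models listed above all wrap an encoder network from `official.nlp.modeling.networks`:

from official.nlp.modeling import models
from official.nlp.modeling import networks

# Build a small BERT-style encoder (placeholder hyperparameters).
encoder = networks.BertEncoder(vocab_size=100, num_layers=2)

# Wrap it in one of the canned models described above; the other models
# (BertTokenClassifier, BertSpanLabeler, BertPretrainer, DualEncoder)
# follow the same pattern of taking the encoder network as an argument.
classifier = models.BertClassifier(encoder, num_classes=2)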
official/nlp/modeling/models/__init__.py (View file @ f16a7b5b)

-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,10 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Models package definition."""
+"""Models are combinations of `tf.keras` layers and models that can be trained.
+
+Several pre-built canned models are provided to train encoder networks.
+These models are intended as both convenience functions and canonical examples.
+"""
 from official.nlp.modeling.models.bert_classifier import BertClassifier
 from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
+from official.nlp.modeling.models.bert_pretrainer import *
 from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
 from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
+from official.nlp.modeling.models.dual_encoder import DualEncoder
 from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
+from official.nlp.modeling.models.seq2seq_transformer import *
+from official.nlp.modeling.models.xlnet import XLNetClassifier
+from official.nlp.modeling.models.xlnet import XLNetPretrainer
+from official.nlp.modeling.models.xlnet import XLNetSpanLabeler
official/nlp/modeling/models/bert_classifier.py (View file @ f16a7b5b)

-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,18 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Trainer network for BERT-style models."""
-# pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
+"""BERT cls-token classifier."""
+# pylint: disable=g-classes-have-attributes
+import collections

 import tensorflow as tf

 from official.nlp.modeling import layers
 from official.nlp.modeling import networks


 @tf.keras.utils.register_keras_serializable(package='Text')
 ...
@@ -37,7 +32,10 @@ class BertClassifier(tf.keras.Model):
   instantiates a classification network based on the passed `num_classes`
   argument. If `num_classes` is set to 1, a regression network is instantiated.

-  Arguments:
+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
+  Args:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
       table via a "get_embedding_table" method.
 ...
@@ -45,8 +43,12 @@ class BertClassifier(tf.keras.Model):
     initializer: The initializer (if any) to use in the classification networks.
       Defaults to a Glorot uniform initializer.
     dropout_rate: The dropout probability of the cls head.
-    use_encoder_pooler: Whether to use the pooler layer pre-defined inside
-      the encoder.
+    use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
+      encoder.
+    cls_head: (Optional) The layer instance to use for the classifier head.
+      It should take in the output from network and produce the final logits.
+      If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
+      'use_encoder_pooler') will be ignored.
   """

   def __init__(self,
 ...
@@ -55,15 +57,11 @@ class BertClassifier(tf.keras.Model):
                initializer='glorot_uniform',
                dropout_rate=0.1,
                use_encoder_pooler=True,
+               cls_head=None,
                **kwargs):
-    self._self_setattr_tracking = False
-    self._network = network
-    self._config = {
-        'network': network,
-        'num_classes': num_classes,
-        'initializer': initializer,
-        'use_encoder_pooler': use_encoder_pooler,
-    }
+    self.num_classes = num_classes
+    self.initializer = initializer
+    self.use_encoder_pooler = use_encoder_pooler

     # We want to use the inputs of the passed network as the inputs to this
     # Model. To do this, we need to keep a handle to the network inputs for use
 ...
@@ -73,36 +71,73 @@ class BertClassifier(tf.keras.Model):
     if use_encoder_pooler:
       # Because we have a copy of inputs to create this Model object, we can
       # invoke the Network object with its own input tensors to start the Model.
-      _, cls_output = network(inputs)
-      cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
+      outputs = network(inputs)
+      if isinstance(outputs, list):
+        cls_inputs = outputs[1]
+      else:
+        cls_inputs = outputs['pooled_output']
+      cls_inputs = tf.keras.layers.Dropout(rate=dropout_rate)(cls_inputs)
+    else:
+      outputs = network(inputs)
+      if isinstance(outputs, list):
+        cls_inputs = outputs[0]
+      else:
+        cls_inputs = outputs['sequence_output']

-      self.classifier = networks.Classification(
-          input_width=cls_output.shape[-1],
-          num_classes=num_classes,
-          initializer=initializer,
-          output='logits',
-          name='sentence_prediction')
-      predictions = self.classifier(cls_output)
+    if cls_head:
+      classifier = cls_head
     else:
-      sequence_output, _ = network(inputs)
-      self.classifier = layers.ClassificationHead(
-          inner_dim=sequence_output.shape[-1],
+      classifier = layers.ClassificationHead(
+          inner_dim=0 if use_encoder_pooler else cls_inputs.shape[-1],
           num_classes=num_classes,
           initializer=initializer,
           dropout_rate=dropout_rate,
           name='sentence_prediction')
-      predictions = self.classifier(sequence_output)
+    predictions = classifier(cls_inputs)

+    # b/164516224
     # Once we've created the network using the Functional API, we call
     # super().__init__ as though we were invoking the Functional API Model
     # constructor, resulting in this object having all the properties of a model
     # created using the Functional API. Once super().__init__ is called, we
     # can assign attributes to `self` - note that all `self` assignments are
     # below this line.
     super(BertClassifier, self).__init__(
         inputs=inputs, outputs=predictions, **kwargs)
+
+    self._network = network
+    self._cls_head = cls_head
+    config_dict = self._make_config_dict()
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self.classifier = classifier

   @property
   def checkpoint_items(self):
-    return dict(encoder=self._network)
+    items = dict(encoder=self._network)
+    if hasattr(self.classifier, 'checkpoint_items'):
+      for key, item in self.classifier.checkpoint_items.items():
+        items['.'.join([self.classifier.name, key])] = item
+    return items

   def get_config(self):
-    return self._config
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config, custom_objects=None):
     return cls(**config)

+  def _make_config_dict(self):
+    return {
+        'network': self._network,
+        'num_classes': self.num_classes,
+        'initializer': self.initializer,
+        'use_encoder_pooler': self.use_encoder_pooler,
+        'cls_head': self._cls_head,
+    }
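For orientation, a minimal usage sketch of the class above (not part of the diff; it mirrors the accompanying test file, and the vocabulary size and layer count are illustrative):

import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_classifier

encoder = networks.BertEncoder(vocab_size=100, num_layers=2)
model = bert_classifier.BertClassifier(encoder, num_classes=3)

# Three int32 tensors: word ids, input mask, and segment/type ids.
word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
mask = tf.constant([[1, 1], [1, 1]], dtype=tf.int32)
type_ids = tf.constant([[0, 0], [0, 0]], dtype=tf.int32)
logits = model([word_ids, mask, type_ids])  # shape [2, num_classes]

# The new `cls_head` argument accepts a custom head layer (for example
# layers.GaussianProcessClassificationHead(inner_dim=0, num_classes=3));
# when it is set, num_classes/initializer/dropout_rate/use_encoder_pooler
# are ignored, as documented above.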
official/nlp/modeling/models/bert_classifier_test.py (View file @ f16a7b5b)

-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,17 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for BERT trainer network."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+"""Tests for BERT trainer network."""
 from absl.testing import parameterized
 import tensorflow as tf

 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
 from official.nlp.modeling import layers
 from official.nlp.modeling import networks
 from official.nlp.modeling.models import bert_classifier
 ...
@@ -31,14 +28,15 @@ from official.nlp.modeling.models import bert_classifier
 @keras_parameterized.run_all_keras_modes
 class BertClassifierTest(keras_parameterized.TestCase):

-  @parameterized.parameters(1, 3)
-  def test_bert_trainer(self, num_classes):
+  @parameterized.named_parameters(('single_cls', 1, False), ('3_cls', 3, False),
+                                  ('3_cls_dictoutputs', 3, True))
+  def test_bert_trainer(self, num_classes, dict_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+    test_network = networks.BertEncoder(
+        vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
 ...
@@ -56,17 +54,22 @@ class BertClassifierTest(keras_parameterized.TestCase):
     expected_classification_shape = [None, num_classes]
     self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())

-  @parameterized.parameters(1, 2)
-  def test_bert_trainer_tensor_call(self, num_classes):
+  @parameterized.named_parameters(
+      ('single_cls', 1, False), ('2_cls', 2, False),
+      ('single_cls_custom_head', 1, True), ('2_cls_custom_head', 2, True))
+  def test_bert_trainer_tensor_call(self, num_classes, use_custom_head):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=2)
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
+    cls_head = layers.GaussianProcessClassificationHead(
+        inner_dim=0, num_classes=num_classes) if use_custom_head else None

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
-        test_network, num_classes=num_classes)
+        test_network, num_classes=num_classes, cls_head=cls_head)

     # Create a set of 2-dimensional data tensors to feed into the model.
     word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
 ...
@@ -78,17 +81,21 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # too complex: this simply ensures we're not hitting runtime errors.)
     _ = bert_trainer_model([word_ids, mask, type_ids])

-  def test_serialize_deserialize(self):
+  @parameterized.named_parameters(
+      ('default_cls_head', None),
+      ('sngp_cls_head', layers.GaussianProcessClassificationHead(
+          inner_dim=0, num_classes=4)))
+  def test_serialize_deserialize(self, cls_head):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=100, num_layers=2, sequence_length=5)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
     bert_trainer_model = bert_classifier.BertClassifier(
-        test_network, num_classes=4, initializer='zeros')
+        test_network, num_classes=4, initializer='zeros', cls_head=cls_head)

     # Create another BERT trainer via serialization and deserialization.
     config = bert_trainer_model.get_config()
 ...
official/nlp/modeling/models/bert_pretrainer.py (View file @ f16a7b5b)

-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,17 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Trainer network for BERT-style models."""
-# pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
+"""BERT Pre-training model."""
+# pylint: disable=g-classes-have-attributes
+import collections

 import copy
 from typing import List, Optional

 from absl import logging
 import gin
 import tensorflow as tf
 ...
@@ -31,17 +28,18 @@ from official.nlp.modeling import networks
 @tf.keras.utils.register_keras_serializable(package='Text')
 class BertPretrainer(tf.keras.Model):
-  """BERT network training model.
+  """BERT pretraining model.

   This is an implementation of the network structure surrounding a transformer
   encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
   for Language Understanding" (https://arxiv.org/abs/1810.04805).

+  [Note] Please use the new `BertPretrainerV2` for your projects.
+
   The BertPretrainer allows a user to pass in a transformer stack, and
   instantiates the masked language model and classification networks that are
   used to create the training objectives.

-  Arguments:
+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
+  Args:
     network: A transformer network. This network should output a sequence output
       and a classification output.
     num_classes: Number of classes to predict from the classification network.
 ...
@@ -52,8 +50,8 @@ class BertPretrainer(tf.keras.Model):
       None, no activation will be used.
     initializer: The initializer (if any) to use in the masked LM and
       classification networks. Defaults to a Glorot uniform initializer.
-    output: The output style for this network. Can be either 'logits' or
-      'predictions'.
+    output: The output style for this network. Can be either `logits` or
+      `predictions`.
   """

   def __init__(self,
 ...
@@ -65,21 +63,12 @@ class BertPretrainer(tf.keras.Model):
                initializer='glorot_uniform',
                output='logits',
                **kwargs):
-    self._self_setattr_tracking = False
-    self._config = {
-        'network': network,
-        'num_classes': num_classes,
-        'num_token_predictions': num_token_predictions,
-        'activation': activation,
-        'initializer': initializer,
-        'output': output,
-    }
-    self.encoder = network
-
     # We want to use the inputs of the passed network as the inputs to this
     # Model. To do this, we need to keep a copy of the network inputs for use
     # when we construct the Model object at the end of init. (We keep a copy
     # because we'll be adding another tensor to the copy later.)
-    network_inputs = self.encoder.inputs
+    network_inputs = network.inputs
     inputs = copy.copy(network_inputs)

     # Because we have a copy of inputs to create this Model object, we can
 ...
@@ -87,7 +76,7 @@ class BertPretrainer(tf.keras.Model):
     # Note that, because of how deferred construction happens, we can't use
     # the copy of the list here - by the time the network is invoked, the list
     # object contains the additional input added below.
-    sequence_output, cls_output = self.encoder(network_inputs)
+    sequence_output, cls_output = network(network_inputs)

     # The encoder network may get outputs from all layers.
     if isinstance(sequence_output, list):
 ...
@@ -95,7 +84,8 @@ class BertPretrainer(tf.keras.Model):
     if isinstance(cls_output, list):
       cls_output = cls_output[-1]
     sequence_output_length = sequence_output.shape.as_list()[1]
-    if sequence_output_length < num_token_predictions:
+    if sequence_output_length is not None and (sequence_output_length <
+                                               num_token_predictions):
       raise ValueError(
           "The passed network's output length is %s, which is less than the "
           'requested num_token_predictions %s.' %
 ...
@@ -108,48 +98,74 @@ class BertPretrainer(tf.keras.Model):
     inputs.append(masked_lm_positions)
     if embedding_table is None:
-      embedding_table = self.encoder.get_embedding_table()
-    self.masked_lm = layers.MaskedLM(
+      embedding_table = network.get_embedding_table()
+    masked_lm = layers.MaskedLM(
         embedding_table=embedding_table,
         activation=activation,
         initializer=initializer,
         output=output,
         name='cls/predictions')
-    lm_outputs = self.masked_lm(
+    lm_outputs = masked_lm(
         sequence_output, masked_positions=masked_lm_positions)

-    self.classification = networks.Classification(
+    classification = networks.Classification(
         input_width=cls_output.shape[-1],
         num_classes=num_classes,
         initializer=initializer,
         output=output,
         name='classification')
-    sentence_outputs = self.classification(cls_output)
+    sentence_outputs = classification(cls_output)

     super(BertPretrainer, self).__init__(
         inputs=inputs,
         outputs=dict(masked_lm=lm_outputs, classification=sentence_outputs),
         **kwargs)
+
+    # b/164516224
+    # Once we've created the network using the Functional API, we call
+    # super().__init__ as though we were invoking the Functional API Model
+    # constructor, resulting in this object having all the properties of a model
+    # created using the Functional API. Once super().__init__ is called, we
+    # can assign attributes to `self` - note that all `self` assignments are
+    # below this line.
+    config_dict = {
+        'network': network,
+        'num_classes': num_classes,
+        'num_token_predictions': num_token_predictions,
+        'activation': activation,
+        'initializer': initializer,
+        'output': output,
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self.encoder = network
+    self.classification = classification
+    self.masked_lm = masked_lm

   def get_config(self):
-    return self._config
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config, custom_objects=None):
     return cls(**config)


+# TODO(hongkuny): Migrate to BertPretrainerV2 for all usages.
 @tf.keras.utils.register_keras_serializable(package='Text')
+@gin.configurable
 class BertPretrainerV2(tf.keras.Model):
   """BERT pretraining model V2.

-  (Experimental).
   Adds the masked language model head and optional classification heads upon the
   transformer encoder.

-  Arguments:
+  Args:
     encoder_network: A transformer network. This network should output a
       sequence output and a classification output.
     mlm_activation: The activation (if any) to use in the masked LM network. If
 ...
@@ -158,11 +174,16 @@ class BertPretrainerV2(tf.keras.Model):
       to a Glorot uniform initializer.
     classification_heads: A list of optional head layers to transform on encoder
       sequence outputs.
+    customized_masked_lm: A customized masked_lm layer. If None, will create
+      a standard layer from `layers.MaskedLM`; if not None, will use the
+      specified masked_lm layer. Above arguments `mlm_activation` and
+      `mlm_initializer` will be ignored.
     name: The name of the model.
+
   Inputs: Inputs defined by the encoder network, plus `masked_lm_positions` as a
     dictionary.
-  Outputs: A dictionary of `lm_output` and classification head outputs keyed by
-    head names.
+  Outputs: A dictionary of `lm_output`, classification head outputs keyed by
+    head names, and also outputs from `encoder_network`, keyed by
+    `sequence_output` and `encoder_outputs` (if any).
   """

   def __init__(
 ...
@@ -171,27 +192,24 @@ class BertPretrainerV2(tf.keras.Model):
       mlm_activation=None,
       mlm_initializer='glorot_uniform',
       classification_heads: Optional[List[tf.keras.layers.Layer]] = None,
+      customized_masked_lm: Optional[tf.keras.layers.Layer] = None,
       name: str = 'bert',
       **kwargs):
-    self._self_setattr_tracking = False
+    super().__init__(self, name=name, **kwargs)
     self._config = {
         'encoder_network': encoder_network,
         'mlm_initializer': mlm_initializer,
         'classification_heads': classification_heads,
         'name': name,
     }
     self.encoder_network = encoder_network
     inputs = copy.copy(self.encoder_network.inputs)
-    sequence_output, _ = self.encoder_network(inputs)
-
     self.classification_heads = classification_heads or []
     if len(set([cls.name for cls in self.classification_heads])) != len(
         self.classification_heads):
       raise ValueError('Classification heads should have unique names.')

-    outputs = dict()
-    self.masked_lm = layers.MaskedLM(
+    self.masked_lm = customized_masked_lm or layers.MaskedLM(
         embedding_table=self.encoder_network.get_embedding_table(),
         activation=mlm_activation,
         initializer=mlm_initializer,
 ...
@@ -199,13 +217,45 @@ class BertPretrainerV2(tf.keras.Model):
     masked_lm_positions = tf.keras.layers.Input(
         shape=(None,), name='masked_lm_positions', dtype=tf.int32)
     inputs.append(masked_lm_positions)
-    outputs['lm_output'] = self.masked_lm(
-        sequence_output, masked_positions=masked_lm_positions)
-    for cls_head in self.classification_heads:
-      outputs[cls_head.name] = cls_head(sequence_output)
-
-    super(BertPretrainerV2, self).__init__(
-        inputs=inputs, outputs=outputs, name=name, **kwargs)
+    self.inputs = inputs
+
+  def call(self, inputs):
+    if isinstance(inputs, list):
+      logging.warning('List inputs to BertPretrainer are discouraged.')
+      inputs = dict([
+          (ref.name, tensor) for ref, tensor in zip(self.inputs, inputs)
+      ])
+
+    outputs = dict()
+    encoder_network_outputs = self.encoder_network(inputs)
+    if isinstance(encoder_network_outputs, list):
+      outputs['pooled_output'] = encoder_network_outputs[1]
+      # When `encoder_network` was instantiated with return_all_encoder_outputs
+      # set to True, `encoder_network_outputs[0]` is a list containing
+      # all transformer layers' output.
+      if isinstance(encoder_network_outputs[0], list):
+        outputs['encoder_outputs'] = encoder_network_outputs[0]
+        outputs['sequence_output'] = encoder_network_outputs[0][-1]
+      else:
+        outputs['sequence_output'] = encoder_network_outputs[0]
+    elif isinstance(encoder_network_outputs, dict):
+      outputs = encoder_network_outputs
+    else:
+      raise ValueError('encoder_network\'s output should be either a list '
+                       'or a dict, but got %s' % encoder_network_outputs)
+
+    sequence_output = outputs['sequence_output']
+    # Inference may not have masked_lm_positions and mlm_logits is not needed.
+    if 'masked_lm_positions' in inputs:
+      masked_lm_positions = inputs['masked_lm_positions']
+      outputs['mlm_logits'] = self.masked_lm(
+          sequence_output, masked_positions=masked_lm_positions)
+    for cls_head in self.classification_heads:
+      cls_outputs = cls_head(sequence_output)
+      if isinstance(cls_outputs, dict):
+        outputs.update(cls_outputs)
+      else:
+        outputs[cls_head.name] = cls_outputs
+    return outputs

   @property
   def checkpoint_items(self):
 ...
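A sketch of how the reworked `BertPretrainerV2` is driven with dictionary inputs (not part of the diff; it condenses the accompanying test, and all sizes are illustrative):

import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer

seq_length, num_predictions = 128, 20
encoder = networks.BertEncoder(
    vocab_size=100, num_layers=2, hidden_size=48,
    max_sequence_length=seq_length, dict_outputs=True)
pretrainer = bert_pretrainer.BertPretrainerV2(encoder_network=encoder)

inputs = dict(
    input_word_ids=tf.keras.Input(shape=(seq_length,), dtype=tf.int32),
    input_mask=tf.keras.Input(shape=(seq_length,), dtype=tf.int32),
    input_type_ids=tf.keras.Input(shape=(seq_length,), dtype=tf.int32),
    masked_lm_positions=tf.keras.Input(
        shape=(num_predictions,), dtype=tf.int32))
outputs = pretrainer(inputs)
# With dict outputs the result carries 'sequence_output', 'pooled_output' and
# 'encoder_outputs' from the encoder, plus 'mlm_logits' because
# masked_lm_positions was supplied; omitting it skips the masked LM head.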
official/nlp/modeling/models/bert_pretrainer_test.py (View file @ f16a7b5b)

-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,16 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for BERT trainer network."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+"""Tests for BERT pretrainer model."""
+import itertools

 from absl.testing import parameterized
 import tensorflow as tf

 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
 from official.nlp.modeling import layers
 from official.nlp.modeling import networks
 from official.nlp.modeling.models import bert_pretrainer
 ...
@@ -35,8 +34,10 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+    test_network = networks.BertEncoder(
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length)

     # Create a BERT trainer with the created network.
     num_classes = 3
 ...
@@ -68,7 +69,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=100, num_layers=2, sequence_length=2)

     # Create a BERT trainer with the created network.
 ...
@@ -90,8 +91,8 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+    test_network = networks.BertEncoder(
+        vocab_size=100, num_layers=2, max_sequence_length=5)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
 ...
@@ -109,36 +110,112 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     self.assertAllEqual(bert_trainer_model.get_config(),
                         new_bert_trainer_model.get_config())

-  def test_bert_pretrainerv2(self):
+
+class BertPretrainerV2Test(keras_parameterized.TestCase):
+
+  @parameterized.parameters(itertools.product(
+      (False, True),
+      (False, True),
+      (False, True),
+      (False, True),
+  ))
+  def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs,
+                             use_customized_masked_lm,
+                             has_masked_lm_positions):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+    hidden_size = 48
+    num_layers = 2
+    test_network = networks.BertEncoder(
+        vocab_size=vocab_size,
+        num_layers=num_layers,
+        hidden_size=hidden_size,
+        max_sequence_length=sequence_length,
+        return_all_encoder_outputs=return_all_encoder_outputs,
+        dict_outputs=dict_outputs)

     # Create a BERT trainer with the created network.
+    if use_customized_masked_lm:
+      customized_masked_lm = layers.MaskedLM(
+          embedding_table=test_network.get_embedding_table())
+    else:
+      customized_masked_lm = None
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
-        encoder_network=test_network)
+        encoder_network=test_network,
+        customized_masked_lm=customized_masked_lm)
     num_token_predictions = 20
     # Create a set of 2-dimensional inputs (the first dimension is implicit).
-    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    lm_mask = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
+    inputs = dict(
+        input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
+        input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
+        input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32))
+    if has_masked_lm_positions:
+      inputs['masked_lm_positions'] = tf.keras.Input(
+          shape=(num_token_predictions,), dtype=tf.int32)

     # Invoke the trainer model on the inputs. This causes the layer to be built.
-    outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
+    outputs = bert_trainer_model(inputs)
+
+    has_encoder_outputs = dict_outputs or return_all_encoder_outputs
+    expected_keys = ['sequence_output', 'pooled_output']
+    if has_encoder_outputs:
+      expected_keys.append('encoder_outputs')
+    if has_masked_lm_positions:
+      expected_keys.append('mlm_logits')
+    self.assertSameElements(outputs.keys(), expected_keys)

     # Validate that the outputs are of the expected shape.
     expected_lm_shape = [None, num_token_predictions, vocab_size]
-    self.assertAllEqual(expected_lm_shape, outputs['lm_output'].shape.as_list())
+    if has_masked_lm_positions:
+      self.assertAllEqual(expected_lm_shape,
+                          outputs['mlm_logits'].shape.as_list())
+
+    expected_sequence_output_shape = [None, sequence_length, hidden_size]
+    self.assertAllEqual(expected_sequence_output_shape,
+                        outputs['sequence_output'].shape.as_list())
+    expected_pooled_output_shape = [None, hidden_size]
+    self.assertAllEqual(expected_pooled_output_shape,
+                        outputs['pooled_output'].shape.as_list())
+
+  def test_multiple_cls_outputs(self):
+    """Validate that the Keras object can be created."""
+    # Build a transformer network to use within the BERT trainer.
+    vocab_size = 100
+    sequence_length = 512
+    hidden_size = 48
+    num_layers = 2
+    test_network = networks.BertEncoder(
+        vocab_size=vocab_size,
+        num_layers=num_layers,
+        hidden_size=hidden_size,
+        max_sequence_length=sequence_length,
+        dict_outputs=True)
+    bert_trainer_model = bert_pretrainer.BertPretrainerV2(
+        encoder_network=test_network,
+        classification_heads=[layers.MultiClsHeads(
+            inner_dim=5, cls_list=[('foo', 2), ('bar', 3)])])
+    num_token_predictions = 20
+    # Create a set of 2-dimensional inputs (the first dimension is implicit).
+    inputs = dict(
+        input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
+        input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
+        input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
+        masked_lm_positions=tf.keras.Input(
+            shape=(num_token_predictions,), dtype=tf.int32))
+
+    # Invoke the trainer model on the inputs. This causes the layer to be built.
+    outputs = bert_trainer_model(inputs)
+    self.assertEqual(outputs['foo'].shape.as_list(), [None, 2])
+    self.assertEqual(outputs['bar'].shape.as_list(), [None, 3])

   def test_v2_serialize_deserialize(self):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=100, num_layers=2, sequence_length=5)

     # Create a BERT trainer with the created network. (Note that all the args
 ...
official/nlp/modeling/models/bert_span_labeler.py (View file @ f16a7b5b)

-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,14 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Trainer network for BERT-style models."""
-# pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
+"""BERT Question Answering model."""
+# pylint: disable=g-classes-have-attributes
+import collections

 import tensorflow as tf

 from official.nlp.modeling import networks
 ...
@@ -32,17 +28,20 @@ class BertSpanLabeler(tf.keras.Model):
   encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
   for Language Understanding" (https://arxiv.org/abs/1810.04805).

-  The BertSpanLabeler allows a user to pass in a transformer stack, and
+  The BertSpanLabeler allows a user to pass in a transformer encoder, and
   instantiates a span labeling network based on a single dense layer.

-  Arguments:
+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
+  Args:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
-      table via a "get_embedding_table" method.
+      table via a `get_embedding_table` method.
     initializer: The initializer (if any) to use in the span labeling network.
       Defaults to a Glorot uniform initializer.
-    output: The output style for this network. Can be either 'logits' or
-      'predictions'.
+    output: The output style for this network. Can be either `logits` or
+      `predictions`.
   """

   def __init__(self,
 ...
@@ -50,13 +49,6 @@ class BertSpanLabeler(tf.keras.Model):
                initializer='glorot_uniform',
                output='logits',
                **kwargs):
-    self._self_setattr_tracking = False
-    self._network = network
-    self._config = {
-        'network': network,
-        'initializer': initializer,
-        'output': output,
-    }

     # We want to use the inputs of the passed network as the inputs to this
     # Model. To do this, we need to keep a handle to the network inputs for use
 ...
@@ -65,16 +57,25 @@ class BertSpanLabeler(tf.keras.Model):
     # Because we have a copy of inputs to create this Model object, we can
     # invoke the Network object with its own input tensors to start the Model.
-    sequence_output, _ = network(inputs)
+    outputs = network(inputs)
+    if isinstance(outputs, list):
+      sequence_output = outputs[0]
+    else:
+      sequence_output = outputs['sequence_output']

     # The input network (typically a transformer model) may get outputs from all
     # layers. When this case happens, we retrieve the last layer output.
     if isinstance(sequence_output, list):
       sequence_output = sequence_output[-1]

     # This is an instance variable for ease of access to the underlying task
     # network.
-    self.span_labeling = networks.SpanLabeling(
+    span_labeling = networks.SpanLabeling(
         input_width=sequence_output.shape[-1],
         initializer=initializer,
         output=output,
         name='span_labeling')
-    start_logits, end_logits = self.span_labeling(sequence_output)
+    start_logits, end_logits = span_labeling(sequence_output)

     # Use identity layers wrapped in lambdas to explicitly name the output
     # tensors. This allows us to use string-keyed dicts in Keras fit/predict/
 ...
@@ -88,15 +89,36 @@ class BertSpanLabeler(tf.keras.Model):
     logits = [start_logits, end_logits]

+    # b/164516224
     # Once we've created the network using the Functional API, we call
     # super().__init__ as though we were invoking the Functional API Model
     # constructor, resulting in this object having all the properties of a model
     # created using the Functional API. Once super().__init__ is called, we
     # can assign attributes to `self` - note that all `self` assignments are
     # below this line.
     super(BertSpanLabeler, self).__init__(
         inputs=inputs, outputs=logits, **kwargs)
+
+    self._network = network
+    config_dict = {
+        'network': network,
+        'initializer': initializer,
+        'output': output,
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self.span_labeling = span_labeling

   @property
   def checkpoint_items(self):
     return dict(encoder=self._network)

   def get_config(self):
-    return self._config
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config, custom_objects=None):
 ...
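A minimal usage sketch of the span labeler (not part of the diff; the sizes are illustrative and mirror the test file):

import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_span_labeler

encoder = networks.BertEncoder(vocab_size=100, num_layers=2)
model = bert_span_labeler.BertSpanLabeler(encoder)

word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
mask = tf.constant([[1, 1], [1, 1]], dtype=tf.int32)
type_ids = tf.constant([[0, 0], [0, 0]], dtype=tf.int32)
start_logits, end_logits = model([word_ids, mask, type_ids])
# Both tensors have shape [batch, seq_len]: per-token logits for the start
# and end positions of the answer span.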
official/nlp/modeling/models/bert_span_labeler_test.py (View file @ f16a7b5b)

-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,13 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for BERT trainer network."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+"""Tests for BERT trainer network."""
 from absl.testing import parameterized
 import tensorflow as tf

 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
 ...
@@ -30,13 +27,14 @@ from official.nlp.modeling.models import bert_span_labeler
 @keras_parameterized.run_all_keras_modes
 class BertSpanLabelerTest(keras_parameterized.TestCase):

-  def test_bert_trainer(self):
+  @parameterized.parameters(True, False)
+  def test_bert_trainer(self, dict_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
 ...
@@ -59,9 +57,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     """Validate compilation using explicit output names."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
-    sequence_length = 512
-    test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+    test_network = networks.BertEncoder(vocab_size=vocab_size, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
 ...
@@ -80,8 +76,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=2)
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
 ...
@@ -100,7 +95,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=100, num_layers=2, sequence_length=5)

     # Create a BERT trainer with the created network. (Note that all the args
 ...
official/nlp/modeling/models/bert_token_classifier.py (View file @ f16a7b5b)

-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,18 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Trainer network for BERT-style models."""
-# pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
+"""BERT token classifier."""
+# pylint: disable=g-classes-have-attributes
+import collections

 import tensorflow as tf

 from official.nlp.modeling import networks


 @tf.keras.utils.register_keras_serializable(package='Text')
 class BertTokenClassifier(tf.keras.Model):
 ...
@@ -36,15 +30,21 @@ class BertTokenClassifier(tf.keras.Model):
   instantiates a token classification network based on the passed `num_classes`
   argument.

-  Arguments:
+  *Note* that the model is constructed by
+  [Keras Functional API](https://keras.io/guides/functional_api/).
+
+  Args:
     network: A transformer network. This network should output a sequence output
       and a classification output. Furthermore, it should expose its embedding
-      table via a "get_embedding_table" method.
+      table via a `get_embedding_table` method.
     num_classes: Number of classes to predict from the classification network.
     initializer: The initializer (if any) to use in the classification networks.
       Defaults to a Glorot uniform initializer.
-    output: The output style for this network. Can be either 'logits' or
-      'predictions'.
+    output: The output style for this network. Can be either `logits` or
+      `predictions`.
     dropout_rate: The dropout probability of the token classification head.
+    output_encoder_outputs: Whether to include intermediate sequence output
+      in the final output.
   """

   def __init__(self,
 ...
@@ -53,15 +53,8 @@ class BertTokenClassifier(tf.keras.Model):
                initializer='glorot_uniform',
                output='logits',
                dropout_rate=0.1,
+               output_encoder_outputs=False,
                **kwargs):
-    self._self_setattr_tracking = False
-    self._network = network
-    self._config = {
-        'network': network,
-        'num_classes': num_classes,
-        'initializer': initializer,
-        'output': output,
-    }

     # We want to use the inputs of the passed network as the inputs to this
     # Model. To do this, we need to keep a handle to the network inputs for use
 ...
@@ -70,27 +63,70 @@ class BertTokenClassifier(tf.keras.Model):
     # Because we have a copy of inputs to create this Model object, we can
     # invoke the Network object with its own input tensors to start the Model.
-    sequence_output, _ = network(inputs)
-    sequence_output = tf.keras.layers.Dropout(rate=dropout_rate)(
-        sequence_output)
-
-    self.classifier = networks.TokenClassification(
-        input_width=sequence_output.shape[-1],
-        num_classes=num_classes,
-        initializer=initializer,
-        output=output,
-        name='classification')
-    predictions = self.classifier(sequence_output)
+    outputs = network(inputs)
+    if isinstance(outputs, list):
+      sequence_output = outputs[0]
+    else:
+      sequence_output = outputs['sequence_output']
+    sequence_output = tf.keras.layers.Dropout(rate=dropout_rate)(
+        sequence_output)
+
+    classifier = tf.keras.layers.Dense(
+        num_classes,
+        activation=None,
+        kernel_initializer=initializer,
+        name='predictions/transform/logits')
+    logits = classifier(sequence_output)
+    if output == 'logits':
+      output_tensors = {'logits': logits}
+    elif output == 'predictions':
+      output_tensors = {
+          'predictions': tf.keras.layers.Activation(tf.nn.log_softmax)(logits)
+      }
+    else:
+      raise ValueError(
+          ('Unknown `output` value "%s". `output` can be either "logits" or '
+           '"predictions"') % output)
+
+    if output_encoder_outputs:
+      output_tensors['encoder_outputs'] = sequence_output

+    # b/164516224
     # Once we've created the network using the Functional API, we call
     # super().__init__ as though we were invoking the Functional API Model
     # constructor, resulting in this object having all the properties of a model
     # created using the Functional API. Once super().__init__ is called, we
     # can assign attributes to `self` - note that all `self` assignments are
     # below this line.
     super(BertTokenClassifier, self).__init__(
-        inputs=inputs, outputs=predictions, **kwargs)
+        inputs=inputs, outputs=output_tensors, **kwargs)
+
+    self._network = network
+    config_dict = {
+        'network': network,
+        'num_classes': num_classes,
+        'initializer': initializer,
+        'output': output,
+        'output_encoder_outputs': output_encoder_outputs
+    }
+    # We are storing the config dict as a namedtuple here to ensure checkpoint
+    # compatibility with an earlier version of this model which did not track
+    # the config dict attribute. TF does not track immutable attrs which
+    # do not contain Trackables, so by creating a config namedtuple instead of
+    # a dict we avoid tracking it.
+    config_cls = collections.namedtuple('Config', config_dict.keys())
+    self._config = config_cls(**config_dict)
+    self.classifier = classifier
+    self.logits = logits

   @property
   def checkpoint_items(self):
     return dict(encoder=self._network)

   def get_config(self):
-    return self._config
+    return dict(self._config._asdict())

   @classmethod
   def from_config(cls, config, custom_objects=None):
 ...
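A usage sketch of the token classifier with the new `output_encoder_outputs` flag (not part of the diff; hyperparameters are illustrative):

import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_token_classifier

encoder = networks.BertEncoder(vocab_size=100, num_layers=2)
model = bert_token_classifier.BertTokenClassifier(
    encoder, num_classes=3, output_encoder_outputs=True)

word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
mask = tf.constant([[1, 1], [1, 1]], dtype=tf.int32)
type_ids = tf.constant([[0, 0], [0, 0]], dtype=tf.int32)
outputs = model([word_ids, mask, type_ids])
logits = outputs['logits']                    # [batch, seq_len, num_classes]
encoder_outputs = outputs['encoder_outputs']  # [batch, seq_len, hidden_size]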
official/nlp/modeling/models/bert_token_classifier_test.py (View file @ f16a7b5b)

-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 ...
@@ -11,13 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for BERT trainer network."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+"""Tests for BERT token classifier."""
 from absl.testing import parameterized
 import tensorflow as tf

 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
 ...
@@ -30,19 +27,26 @@ from official.nlp.modeling.models import bert_token_classifier
 @keras_parameterized.run_all_keras_modes
 class BertTokenClassifierTest(keras_parameterized.TestCase):

-  def test_bert_trainer(self):
+  @parameterized.parameters((True, True), (False, False))
+  def test_bert_trainer(self, dict_outputs, output_encoder_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+    hidden_size = 768
+    test_network = networks.BertEncoder(
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length,
+        dict_outputs=dict_outputs,
+        hidden_size=hidden_size)

     # Create a BERT trainer with the created network.
     num_classes = 3
     bert_trainer_model = bert_token_classifier.BertTokenClassifier(
-        test_network, num_classes=num_classes)
+        test_network,
+        num_classes=num_classes,
+        output_encoder_outputs=output_encoder_outputs)

     # Create a set of 2-dimensional inputs (the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
 ...
@@ -50,19 +54,25 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

     # Invoke the trainer model on the inputs. This causes the layer to be built.
-    sequence_outs = bert_trainer_model([word_ids, mask, type_ids])
+    outputs = bert_trainer_model([word_ids, mask, type_ids])
+    if output_encoder_outputs:
+      logits = outputs['logits']
+      encoder_outputs = outputs['encoder_outputs']
+      self.assertAllEqual(encoder_outputs.shape.as_list(),
+                          [None, sequence_length, hidden_size])
+    else:
+      logits = outputs['logits']

     # Validate that the outputs are of the expected shape.
     expected_classification_shape = [None, sequence_length, num_classes]
-    self.assertAllEqual(expected_classification_shape,
-                        sequence_outs.shape.as_list())
+    self.assertAllEqual(expected_classification_shape, logits.shape.as_list())

   def test_bert_trainer_tensor_call(self):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=2)
+    test_network = networks.BertEncoder(
+        vocab_size=100, num_layers=2, max_sequence_length=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_token_classifier.BertTokenClassifier(
 ...
@@ -82,8 +92,8 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+    test_network = networks.BertEncoder(
+        vocab_size=100, num_layers=2, max_sequence_length=5)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
 ...
official/nlp/modeling/models/dual_encoder.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer network for dual encoder style models."""
# pylint: disable=g-classes-have-attributes
import collections

import tensorflow as tf

from official.nlp.modeling import layers


@tf.keras.utils.register_keras_serializable(package='Text')
class DualEncoder(tf.keras.Model):
  """A dual encoder model based on a transformer-based encoder.

  This is an implementation of the dual encoder network structure based on the
  transformer stack, as described in ["Language-agnostic BERT Sentence
  Embedding"](https://arxiv.org/abs/2007.01852).

  The DualEncoder allows a user to pass in a transformer stack, and build a
  dual encoder model based on that transformer stack.

  Args:
    network: A transformer network which should output an encoding output.
    max_seq_length: The maximum allowed sequence length for the transformer.
    normalize: If set to True, normalize the encoding produced by the
      transformer.
    logit_scale: The scaling factor of dot products when doing training.
    logit_margin: The margin between positive and negative when doing training.
    output: The output style for this network. Can be either `logits` or
      `predictions`. If set to `predictions`, it will output the embedding
      produced by the transformer network.
  """

  def __init__(self,
               network: tf.keras.Model,
               max_seq_length: int = 32,
               normalize: bool = True,
               logit_scale: float = 1.0,
               logit_margin: float = 0.0,
               output: str = 'logits',
               **kwargs) -> None:

    if output == 'logits':
      left_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='left_word_ids')
      left_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='left_mask')
      left_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='left_type_ids')
    else:
      # Keep consistent with legacy BERT hub module input names.
      left_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
      left_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
      left_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')

    left_inputs = [left_word_ids, left_mask, left_type_ids]
    left_outputs = network(left_inputs)
    if isinstance(left_outputs, list):
      left_sequence_output, left_encoded = left_outputs
    else:
      left_sequence_output = left_outputs['sequence_output']
      left_encoded = left_outputs['pooled_output']
    if normalize:
      left_encoded = tf.keras.layers.Lambda(
          lambda x: tf.nn.l2_normalize(x, axis=1))(left_encoded)

    if output == 'logits':
      right_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_word_ids')
      right_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_mask')
      right_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_type_ids')

      right_inputs = [right_word_ids, right_mask, right_type_ids]
      right_outputs = network(right_inputs)
      if isinstance(right_outputs, list):
        _, right_encoded = right_outputs
      else:
        right_encoded = right_outputs['pooled_output']
      if normalize:
        right_encoded = tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(x, axis=1))(right_encoded)

      dot_products = layers.MatMulWithMargin(
          logit_scale=logit_scale,
          logit_margin=logit_margin,
          name='dot_product')

      inputs = [
          left_word_ids, left_mask, left_type_ids, right_word_ids, right_mask,
          right_type_ids
      ]
      left_logits, right_logits = dot_products(left_encoded, right_encoded)
      outputs = dict(left_logits=left_logits, right_logits=right_logits)

    elif output == 'predictions':
      inputs = [left_word_ids, left_mask, left_type_ids]
      # To keep consistent with legacy BERT hub modules, the outputs are
      # "pooled_output" and "sequence_output".
      outputs = dict(
          sequence_output=left_sequence_output, pooled_output=left_encoded)
    else:
      raise ValueError('output type %s is not supported' % output)

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a
    # model created using the Functional API. Once super().__init__ is called,
    # we can assign attributes to `self` - note that all `self` assignments
    # are below this line.
    super(DualEncoder, self).__init__(inputs=inputs, outputs=outputs, **kwargs)

    config_dict = {
        'network': network,
        'max_seq_length': max_seq_length,
        'normalize': normalize,
        'logit_scale': logit_scale,
        'logit_margin': logit_margin,
        'output': output,
    }
    # We are storing the config dict as a namedtuple here to ensure checkpoint
    # compatibility with an earlier version of this model which did not track
    # the config dict attribute. TF does not track immutable attrs which
    # do not contain Trackables, so by creating a config namedtuple instead of
    # a dict we avoid tracking it.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)
    self.network = network

  def get_config(self):
    return dict(self._config._asdict())

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.network)
    return items
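Usage sketch (not part of the commit): the DualEncoder runs both sides of a pair through the same encoder and, in the `logits` output mode, returns in-batch similarity logits. A minimal sketch, assuming a small BertEncoder configured as in the tests below:

# Hedged sketch: encoder construction mirrors dual_encoder_test.py; tensor
# sizes are illustrative only.
import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import dual_encoder

seq_len = 32
encoder = networks.BertEncoder(
    vocab_size=100, num_layers=2, hidden_size=768,
    sequence_length=seq_len, dict_outputs=True)
model = dual_encoder.DualEncoder(
    encoder, max_seq_length=seq_len, normalize=True,
    logit_scale=1.0, logit_margin=0.0, output='logits')

ids = tf.ones((4, seq_len), dtype=tf.int32)
mask = tf.ones((4, seq_len), dtype=tf.int32)
segs = tf.zeros((4, seq_len), dtype=tf.int32)
# Inputs are [left_word_ids, left_mask, left_type_ids,
#             right_word_ids, right_mask, right_type_ids].
outputs = model([ids, mask, segs, ids, mask, segs])
print(outputs['left_logits'].shape)  # in-batch similarity logits (left vs. right)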
official/nlp/modeling/models/dual_encoder_test.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dual encoder network."""

from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
from official.nlp.modeling.models import dual_encoder


# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class DualEncoderTest(keras_parameterized.TestCase):

  @parameterized.parameters((192, 'logits'), (768, 'predictions'))
  def test_dual_encoder(self, hidden_size, output):
    """Validate that the Keras object can be created."""
    # Build a transformer network to use within the dual encoder model.
    vocab_size = 100
    sequence_length = 512
    test_network = networks.BertEncoder(
        vocab_size=vocab_size,
        num_layers=2,
        hidden_size=hidden_size,
        sequence_length=sequence_length,
        dict_outputs=True)

    # Create a dual encoder model with the created network.
    dual_encoder_model = dual_encoder.DualEncoder(
        test_network, max_seq_length=sequence_length, output=output)

    # Create a set of 2-dimensional inputs (the first dimension is implicit).
    left_word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    left_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    left_type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    right_word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    right_mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    right_type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

    if output == 'logits':
      outputs = dual_encoder_model([
          left_word_ids, left_mask, left_type_ids, right_word_ids, right_mask,
          right_type_ids
      ])
      _ = outputs['left_logits']
    elif output == 'predictions':
      outputs = dual_encoder_model([left_word_ids, left_mask, left_type_ids])

      # Validate that the outputs are of the expected shape.
      expected_sequence_shape = [None, sequence_length, 768]
      self.assertAllEqual(expected_sequence_shape,
                          outputs['sequence_output'].shape.as_list())
      left_encoded = outputs['pooled_output']
      expected_encoding_shape = [None, 768]
      self.assertAllEqual(expected_encoding_shape,
                          left_encoded.shape.as_list())

  @parameterized.parameters((192, 'logits'), (768, 'predictions'))
  def test_dual_encoder_tensor_call(self, hidden_size, output):
    """Validate that the Keras object can be invoked."""
    # Build a transformer network to use within the dual encoder model. (Here,
    # we use a short sequence_length for convenience.)
    sequence_length = 2
    test_network = networks.BertEncoder(
        vocab_size=100, num_layers=2, sequence_length=sequence_length)

    # Create a dual encoder model with the created network.
    dual_encoder_model = dual_encoder.DualEncoder(
        test_network, max_seq_length=sequence_length, output=output)

    # Create a set of 2-dimensional data tensors to feed into the model.
    word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
    mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
    type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)

    # Invoke the model on the tensors. In Eager mode, this does the actual
    # calculation. (We can't validate the outputs, since the network is too
    # complex: this simply ensures we're not hitting runtime errors.)
    if output == 'logits':
      _ = dual_encoder_model(
          [word_ids, mask, type_ids, word_ids, mask, type_ids])
    elif output == 'predictions':
      _ = dual_encoder_model([word_ids, mask, type_ids])

  def test_serialize_deserialize(self):
    """Validate that the dual encoder model can be serialized/deserialized."""
    # Build a transformer network to use within the dual encoder model. (Here,
    # we use a short sequence_length for convenience.)
    sequence_length = 32
    test_network = networks.BertEncoder(
        vocab_size=100, num_layers=2, sequence_length=sequence_length)

    # Create a dual encoder model with the created network. (Note that all the
    # args are different, so we can catch any serialization mismatches.)
    dual_encoder_model = dual_encoder.DualEncoder(
        test_network, max_seq_length=sequence_length, output='predictions')

    # Create another dual encoder model via serialization and deserialization.
    config = dual_encoder_model.get_config()
    new_dual_encoder = dual_encoder.DualEncoder.from_config(config)

    # Validate that the config can be forced to JSON.
    _ = new_dual_encoder.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(dual_encoder_model.get_config(),
                        new_dual_encoder.get_config())


if __name__ == '__main__':
  tf.test.main()
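Illustration (not part of the commit): the `logit_scale`/`logit_margin` arguments refer to an in-batch scaled dot product with a margin on the positive pairs. The sketch below is an assumption about what `layers.MatMulWithMargin` computes, written as plain TensorFlow for clarity; it is not that layer's actual source.

import tensorflow as tf

def in_batch_logits(left, right, logit_scale=1.0, logit_margin=0.0):
  # left, right: [batch, dim] L2-normalized encodings of paired sentences.
  sims = tf.matmul(left, right, transpose_b=True)          # [batch, batch]
  # Subtract the margin from the diagonal (the positive pairs), then scale.
  sims = sims - logit_margin * tf.eye(tf.shape(left)[0])
  return logit_scale * sims

left = tf.math.l2_normalize(tf.random.normal((4, 8)), axis=1)
right = tf.math.l2_normalize(tf.random.normal((4, 8)), axis=1)
print(in_batch_logits(left, right, logit_scale=5.0, logit_margin=0.3).shape)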
official/nlp/modeling/models/electra_pretrainer.py (modified)

-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -11,15 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# ==============================================================================
 """Trainer network for ELECTRA models."""
 # pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
 import copy

 import tensorflow as tf

 from official.modeling import tf_utils
...
@@ -39,7 +36,10 @@ class ElectraPretrainer(tf.keras.Model):
   model (at generator side) and classification networks (at discriminator side)
   that are used to create the training objectives.

-  Arguments:
+  *Note* that the model is constructed by Keras Subclass API, where layers are
+  defined inside `__init__` and `call()` implements the computation.
+
+  Args:
     generator_network: A transformer network for generator, this network should
       output a sequence output and an optional classification output.
     discriminator_network: A transformer network for discriminator, this network
...
@@ -47,15 +47,13 @@ class ElectraPretrainer(tf.keras.Model):
     vocab_size: Size of generator output vocabulary
     num_classes: Number of classes to predict from the classification network
       for the generator network (not used now)
-    sequence_length: Input sequence length
-    last_hidden_dim: Last hidden dim of generator transformer output
     num_token_predictions: Number of tokens to predict from the masked LM.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
     mlm_initializer: The initializer (if any) to use in the masked LM and
       classification networks. Defaults to a Glorot uniform initializer.
-    output_type: The output style for this network. Can be either 'logits' or
-      'predictions'.
+    output_type: The output style for this network. Can be either `logits` or
+      `predictions`.
     disallow_correct: Whether to disallow the generator to generate the exact
       same token in the original sentence
   """
...
@@ -65,8 +63,6 @@ class ElectraPretrainer(tf.keras.Model):
                discriminator_network,
                vocab_size,
                num_classes,
-               sequence_length,
-               last_hidden_dim,
                num_token_predictions,
                mlm_activation=None,
                mlm_initializer='glorot_uniform',
...
@@ -79,8 +75,6 @@ class ElectraPretrainer(tf.keras.Model):
         'discriminator_network': discriminator_network,
         'vocab_size': vocab_size,
         'num_classes': num_classes,
-        'sequence_length': sequence_length,
-        'last_hidden_dim': last_hidden_dim,
         'num_token_predictions': num_token_predictions,
         'mlm_activation': mlm_activation,
         'mlm_initializer': mlm_initializer,
...
@@ -94,8 +88,6 @@ class ElectraPretrainer(tf.keras.Model):
     self.discriminator_network = discriminator_network
     self.vocab_size = vocab_size
     self.num_classes = num_classes
-    self.sequence_length = sequence_length
-    self.last_hidden_dim = last_hidden_dim
     self.num_token_predictions = num_token_predictions
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer
...
@@ -108,10 +100,15 @@ class ElectraPretrainer(tf.keras.Model):
         output=output_type,
         name='generator_masked_lm')
     self.classification = layers.ClassificationHead(
-        inner_dim=last_hidden_dim,
+        inner_dim=generator_network.get_config()['hidden_size'],
         num_classes=num_classes,
         initializer=mlm_initializer,
         name='generator_classification_head')
+    self.discriminator_projection = tf.keras.layers.Dense(
+        units=discriminator_network.get_config()['hidden_size'],
+        activation=mlm_activation,
+        kernel_initializer=mlm_initializer,
+        name='discriminator_projection_head')
     self.discriminator_head = tf.keras.layers.Dense(
         units=1, kernel_initializer=mlm_initializer)
...
@@ -123,13 +120,13 @@ class ElectraPretrainer(tf.keras.Model):
     Returns:
       outputs: A dict of pretrainer model outputs, including
-        (1) lm_outputs: a [batch_size, num_token_predictions, vocab_size] tensor
-          indicating logits on masked positions.
-        (2) sentence_outputs: a [batch_size, num_classes] tensor indicating
+        (1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
+          tensor indicating logits on masked positions.
+        (2) sentence_outputs: A `[batch_size, num_classes]` tensor indicating
           logits for nsp task.
-        (3) disc_logits: a [batch_size, sequence_length] tensor indicating
+        (3) disc_logits: A `[batch_size, sequence_length]` tensor indicating
           logits for discriminator replaced token detection task.
-        (4) disc_label: a [batch_size, sequence_length] tensor indicating
+        (4) disc_label: A `[batch_size, sequence_length]` tensor indicating
           target labels for discriminator replaced token detection task.
     """
     input_word_ids = inputs['input_word_ids']
...
@@ -138,14 +135,11 @@ class ElectraPretrainer(tf.keras.Model):
     masked_lm_positions = inputs['masked_lm_positions']

     ### Generator ###
-    sequence_output, cls_output = self.generator_network(
-        [input_word_ids, input_mask, input_type_ids])
+    sequence_output = self.generator_network(
+        [input_word_ids, input_mask, input_type_ids])['sequence_output']

     # The generator encoder network may get outputs from all layers.
     if isinstance(sequence_output, list):
       sequence_output = sequence_output[-1]
-    if isinstance(cls_output, list):
-      cls_output = cls_output[-1]

     lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)
     sentence_outputs = self.classification(sequence_output)
...
@@ -156,16 +150,17 @@ class ElectraPretrainer(tf.keras.Model):
     ### Discriminator ###
     disc_input = fake_data['inputs']
     disc_label = fake_data['is_fake_tokens']
-    disc_sequence_output, _ = self.discriminator_network([
+    disc_sequence_output = self.discriminator_network([
         disc_input['input_word_ids'], disc_input['input_mask'],
         disc_input['input_type_ids']
-    ])
+    ])['sequence_output']

     # The discriminator encoder network may get outputs from all layers.
     if isinstance(disc_sequence_output, list):
       disc_sequence_output = disc_sequence_output[-1]

-    disc_logits = self.discriminator_head(disc_sequence_output)
+    disc_logits = self.discriminator_head(
+        self.discriminator_projection(disc_sequence_output))
     disc_logits = tf.squeeze(disc_logits, axis=-1)

     outputs = {
...
@@ -181,7 +176,7 @@ class ElectraPretrainer(tf.keras.Model):
     """Generate corrupted data for discriminator.

     Args:
-      inputs: A dict of all inputs, same as the input of call() function
+      inputs: A dict of all inputs, same as the input of `call()` function
       mlm_logits: The generator's output logits
       duplicate: Whether to copy the original inputs dict during modifications
...
@@ -214,6 +209,12 @@ class ElectraPretrainer(tf.keras.Model):
         'sampled_tokens': sampled_tokens
     }

+  @property
+  def checkpoint_items(self):
+    """Returns a dictionary of items to be additionally checkpointed."""
+    items = dict(encoder=self.discriminator_network)
+    return items
+
   def get_config(self):
     return self._config
...
@@ -226,16 +227,18 @@ def scatter_update(sequence, updates, positions):
   """Scatter-update a sequence.

   Args:
-    sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
-    updates: A tensor of size batch_size*seq_len(*depth)
-    positions: A [batch_size, n_positions] tensor
+    sequence: A `[batch_size, seq_len]` or `[batch_size, seq_len, depth]`
+      tensor.
+    updates: A tensor of size `batch_size*seq_len(*depth)`.
+    positions: A `[batch_size, n_positions]` tensor.

   Returns:
-    updated_sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
-      tensor of "sequence" with elements at "positions" replaced by the values
-      at "updates". Updates to index 0 are ignored. If there are duplicated
-      positions the update is only applied once.
-    updates_mask: A [batch_size, seq_len] mask tensor of which inputs were
+    updated_sequence: A `[batch_size, seq_len]` or
+      `[batch_size, seq_len, depth]` tensor of "sequence" with elements at
+      "positions" replaced by the values at "updates". Updates to index 0 are
+      ignored. If there are duplicated positions the update is only
+      applied once.
+    updates_mask: A `[batch_size, seq_len]` mask tensor of which inputs were
       updated.
   """
   shape = tf_utils.get_shape_list(sequence, expected_rank=[2, 3])
...
@@ -288,14 +291,14 @@ def sample_from_softmax(logits, disallow=None):
   """Implement softmax sampling using gumbel softmax trick.

   Args:
-    logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
-      the generator output logits for each masked position.
+    logits: A `[batch_size, num_token_predictions, vocab_size]` tensor
+      indicating the generator output logits for each masked position.
     disallow: If `None`, we directly sample tokens from the logits. Otherwise,
-      this is a tensor of size [batch_size, num_token_predictions, vocab_size]
+      this is a tensor of size `[batch_size, num_token_predictions, vocab_size]`
       indicating the true word id in each masked position.

   Returns:
-    sampled_tokens: A [batch_size, num_token_predictions, vocab_size] one hot
+    sampled_tokens: A `[batch_size, num_token_predictions, vocab_size]` one hot
       tensor indicating the sampled word id in each masked position.
   """
   if disallow is not None:
...
View file @
f16a7b5b
# Copyright 202
0
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA pre trainer network."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
"""Tests for ELECTRA pre trainer network."""
import
tensorflow
as
tf
...
...
@@ -35,10 +31,16 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
# Build a transformer network to use within the ELECTRA trainer.
vocab_size
=
100
sequence_length
=
512
test_generator_network
=
networks
.
TransformerEncoder
(
vocab_size
=
vocab_size
,
num_layers
=
2
,
sequence_length
=
sequence_length
)
test_discriminator_network
=
networks
.
TransformerEncoder
(
vocab_size
=
vocab_size
,
num_layers
=
2
,
sequence_length
=
sequence_length
)
test_generator_network
=
networks
.
BertEncoder
(
vocab_size
=
vocab_size
,
num_layers
=
2
,
max_sequence_length
=
sequence_length
,
dict_outputs
=
True
)
test_discriminator_network
=
networks
.
BertEncoder
(
vocab_size
=
vocab_size
,
num_layers
=
2
,
max_sequence_length
=
sequence_length
,
dict_outputs
=
True
)
# Create a ELECTRA trainer with the created network.
num_classes
=
3
...
...
@@ -48,8 +50,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
discriminator_network
=
test_discriminator_network
,
vocab_size
=
vocab_size
,
num_classes
=
num_classes
,
sequence_length
=
sequence_length
,
last_hidden_dim
=
768
,
num_token_predictions
=
num_token_predictions
,
disallow_correct
=
True
)
...
...
@@ -89,10 +89,10 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the ELECTRA trainer. (Here, we
# use a short sequence_length for convenience.)
test_generator_network
=
networks
.
Transform
erEncoder
(
vocab_size
=
100
,
num_layers
=
4
,
sequence_length
=
3
)
test_discriminator_network
=
networks
.
Transform
erEncoder
(
vocab_size
=
100
,
num_layers
=
4
,
sequence_length
=
3
)
test_generator_network
=
networks
.
B
er
t
Encoder
(
vocab_size
=
100
,
num_layers
=
4
,
max_
sequence_length
=
3
,
dict_outputs
=
True
)
test_discriminator_network
=
networks
.
B
er
t
Encoder
(
vocab_size
=
100
,
num_layers
=
4
,
max_
sequence_length
=
3
,
dict_outputs
=
True
)
# Create a ELECTRA trainer with the created network.
eletrca_trainer_model
=
electra_pretrainer
.
ElectraPretrainer
(
...
...
@@ -101,7 +101,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
vocab_size
=
100
,
num_classes
=
2
,
sequence_length
=
3
,
last_hidden_dim
=
768
,
num_token_predictions
=
2
)
# Create a set of 2-dimensional data tensors to feed into the model.
...
...
@@ -127,10 +126,10 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
"""Validate that the ELECTRA trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_generator_network
=
networks
.
Transform
erEncoder
(
vocab_size
=
100
,
num_layers
=
4
,
sequence_length
=
3
)
test_discriminator_network
=
networks
.
Transform
erEncoder
(
vocab_size
=
100
,
num_layers
=
4
,
sequence_length
=
3
)
test_generator_network
=
networks
.
B
er
t
Encoder
(
vocab_size
=
100
,
num_layers
=
4
,
max_
sequence_length
=
3
)
test_discriminator_network
=
networks
.
B
er
t
Encoder
(
vocab_size
=
100
,
num_layers
=
4
,
max_
sequence_length
=
3
)
# Create a ELECTRA trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
...
...
@@ -140,7 +139,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
vocab_size
=
100
,
num_classes
=
2
,
sequence_length
=
3
,
last_hidden_dim
=
768
,
num_token_predictions
=
2
)
# Create another BERT trainer via serialization and deserialization.
...
...
official/nlp/modeling/models/seq2seq_transformer.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement Seq2Seq Transformer model by TF official NLP library.

Model paper: https://arxiv.org/pdf/1706.03762.pdf
"""
import math

import tensorflow as tf

from official.modeling import tf_utils
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.modeling.ops import beam_search

EOS_ID = 1


@tf.keras.utils.register_keras_serializable(package="Text")
class Seq2SeqTransformer(tf.keras.Model):
  """Transformer model with Keras.

  Implemented as described in: https://arxiv.org/pdf/1706.03762.pdf

  The Transformer model consists of an encoder and decoder. The input is an
  int sequence (or a batch of sequences). The encoder produces a continuous
  representation, and the decoder uses the encoder output to generate
  probabilities for the output sequence.
  """

  def __init__(self,
               vocab_size=33708,
               embedding_width=512,
               dropout_rate=0.0,
               padded_decode=False,
               decode_max_length=None,
               extra_decode_length=0,
               beam_size=4,
               alpha=0.6,
               encoder_layer=None,
               decoder_layer=None,
               eos_id=EOS_ID,
               **kwargs):
    """Initialize layers to build Transformer model.

    Args:
      vocab_size: Size of vocabulary.
      embedding_width: Size of hidden layer for embedding.
      dropout_rate: Dropout probability.
      padded_decode: Whether max_sequence_length padding is used. If set
        False, max_sequence_length padding is not used.
      decode_max_length: Maximum number of steps to decode a sequence.
      extra_decode_length: Beam search will run extra steps to decode.
      beam_size: Number of beams for beam search.
      alpha: The strength of length normalization for beam search.
      encoder_layer: An initialized encoder layer.
      decoder_layer: An initialized decoder layer.
      eos_id: Id of end of sentence token.
      **kwargs: other keyword arguments.
    """
    super().__init__(**kwargs)
    self._vocab_size = vocab_size
    self._embedding_width = embedding_width
    self._dropout_rate = dropout_rate
    self._padded_decode = padded_decode
    self._decode_max_length = decode_max_length
    self._extra_decode_length = extra_decode_length
    self._beam_size = beam_size
    self._alpha = alpha
    self._eos_id = eos_id
    self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
        vocab_size=self._vocab_size,
        embedding_width=self._embedding_width,
        initializer=tf.random_normal_initializer(
            mean=0., stddev=self._embedding_width**-0.5),
        scale_factor=self._embedding_width**0.5)
    self.encoder_layer = encoder_layer
    self.decoder_layer = decoder_layer
    self.position_embedding = layers.RelativePositionEmbedding(
        hidden_size=self._embedding_width)
    self.encoder_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    self.decoder_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

  def get_config(self):
    config = {
        "vocab_size": self._vocab_size,
        "hidden_size": self._embedding_width,
        "dropout_rate": self._dropout_rate,
        "padded_decode": self._padded_decode,
        "decode_max_length": self._decode_max_length,
        "eos_id": self._eos_id,
        "extra_decode_length": self._extra_decode_length,
        "beam_size": self._beam_size,
        "alpha": self._alpha,
        "encoder_layer": self.encoder_layer,
        "decoder_layer": self.decoder_layer
    }
    base_config = super(Seq2SeqTransformer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def _embedding_linear(self, embedding_matrix, x):
    """Uses embeddings as linear transformation weights."""
    batch_size = tf.shape(x)[0]
    length = tf.shape(x)[1]
    hidden_size = tf.shape(x)[2]
    vocab_size = tf.shape(embedding_matrix)[0]
    x = tf.reshape(x, [-1, hidden_size])
    logits = tf.matmul(x, tf.cast(embedding_matrix, x.dtype), transpose_b=True)

    return tf.reshape(logits, [batch_size, length, vocab_size])

  def call(self, inputs):
    """Calculate target logits or inferred target sequences.

    Args:
      inputs: a dictionary of tensors.
        Feature `inputs`: int tensor with shape `[batch_size, input_length]`.
        Feature `targets` (optional): None or int tensor with shape
          `[batch_size, target_length]`.

    Returns:
      If targets is defined, then return logits for each word in the target
      sequence, which is a float tensor with shape
      `(batch_size, target_length, vocab_size)`. If targets is `None`, then
      generate output sequence one token at a time and return a dictionary {
        outputs: `(batch_size, decoded_length)`
        scores: `(batch_size, 1)`}
      Even when `float16` is used, the output tensor(s) are always `float32`.

    Raises:
      NotImplementedError: If trying to use the padded decode method on
        CPU/GPUs.
    """
    sources = inputs["inputs"]
    targets = inputs.get("targets", None)

    # Prepare inputs to the layer stack by adding positional encodings and
    # applying dropout.
    embedded_inputs = self.embedding_lookup(sources)
    embedding_mask = tf.cast(tf.not_equal(sources, 0), embedded_inputs.dtype)
    embedded_inputs *= tf.expand_dims(embedding_mask, -1)
    # Attention_mask generation.
    input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
    attention_mask = tf.cast(
        tf.reshape(
            tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
        dtype=sources.dtype)
    broadcast_ones = tf.ones(
        shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
    attention_mask = broadcast_ones * attention_mask

    pos_encoding = self.position_embedding(embedded_inputs)
    pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
    encoder_inputs = embedded_inputs + pos_encoding

    encoder_inputs = self.encoder_dropout(encoder_inputs)

    encoder_outputs = self.encoder_layer(
        encoder_inputs, attention_mask=attention_mask)

    if targets is None:
      if self._padded_decode:
        max_decode_length = self._decode_max_length
      else:
        max_decode_length = self._decode_max_length or (
            tf.shape(encoder_outputs)[1] + self._extra_decode_length)
      symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)

      batch_size = tf.shape(encoder_outputs)[0]
      # Create initial set of IDs that will be passed to symbols_to_logits_fn.
      initial_ids = tf.zeros([batch_size], dtype=tf.int32)

      # Create cache storing decoder attention values for each layer.
      init_decode_length = (max_decode_length if self._padded_decode else 0)
      num_heads = self.decoder_layer.num_attention_heads
      dim_per_head = self._embedding_width // num_heads

      # Cache dtype needs to match beam_search dtype.
      # pylint: disable=g-complex-comprehension
      cache = {
          str(layer): {
              "key":
                  tf.zeros(
                      [batch_size, init_decode_length, num_heads, dim_per_head],
                      dtype=self.compute_dtype),
              "value":
                  tf.zeros(
                      [batch_size, init_decode_length, num_heads, dim_per_head],
                      dtype=self.compute_dtype)
          } for layer in range(self.decoder_layer.num_layers)
      }
      # pylint: enable=g-complex-comprehension

      # Add encoder output and attention bias to the cache.
      encoder_outputs = tf.cast(encoder_outputs, dtype=self.compute_dtype)
      attention_mask = tf.cast(
          tf.reshape(
              tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
          dtype=self.compute_dtype)
      cache["encoder_outputs"] = encoder_outputs
      cache["encoder_decoder_attention_mask"] = attention_mask

      # Use beam search to find the top beam_size sequences and scores.
      decoded_ids, scores = beam_search.sequence_beam_search(
          symbols_to_logits_fn=symbols_to_logits_fn,
          initial_ids=initial_ids,
          initial_cache=cache,
          vocab_size=self._vocab_size,
          beam_size=self._beam_size,
          alpha=self._alpha,
          max_decode_length=max_decode_length,
          eos_id=self._eos_id,
          padded_decode=self._padded_decode,
          dtype=self.compute_dtype)

      # Get the top sequence for each batch element
      top_decoded_ids = decoded_ids[:, 0, 1:]
      top_scores = scores[:, 0]

      return {"outputs": top_decoded_ids, "scores": top_scores}

    decoder_inputs = self.embedding_lookup(targets)
    embedding_mask = tf.cast(tf.not_equal(targets, 0), decoder_inputs.dtype)
    decoder_inputs *= tf.expand_dims(embedding_mask, -1)
    # Shift targets to the right, and remove the last element
    decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
    length = tf.shape(decoder_inputs)[1]
    pos_encoding = self.position_embedding(decoder_inputs)
    pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
    decoder_inputs += pos_encoding

    decoder_inputs = self.decoder_dropout(decoder_inputs)

    decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3)
    batch_size = decoder_shape[0]
    decoder_length = decoder_shape[1]

    self_attention_mask = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
    self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
    self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])

    attention_mask = tf.cast(
        tf.expand_dims(tf.not_equal(sources, 0), axis=1), dtype=sources.dtype)
    attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

    outputs = self.decoder_layer(
        decoder_inputs,
        encoder_outputs,
        self_attention_mask=self_attention_mask,
        cross_attention_mask=attention_mask)
    logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
    # Model outputs should be float32 to avoid numeric issues.
    # https://www.tensorflow.org/guide/mixed_precision#building_the_model
    logits = tf.cast(logits, tf.float32)
    return logits

  def _get_symbols_to_logits_fn(self, max_decode_length):
    """Returns a decoding function that calculates logits of the next tokens."""
    timing_signal = self.position_embedding(
        inputs=None, length=max_decode_length + 1)
    timing_signal = tf.cast(timing_signal, dtype=self.compute_dtype)
    decoder_self_attention_mask = tf.linalg.band_part(
        tf.ones([max_decode_length, max_decode_length],
                dtype=self.compute_dtype), -1, 0)
    decoder_self_attention_mask = tf.reshape(
        decoder_self_attention_mask, [1, max_decode_length, max_decode_length])

    def symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next potential IDs.

      Args:
        ids: Current decoded sequences. int tensor with shape `(batch_size *
          beam_size, i + 1)`.
        i: Loop index.
        cache: Dictionary of values storing the encoder output, encoder-decoder
          attention bias, and previous decoder attention values.

      Returns:
        Tuple of
          (logits with shape `(batch_size * beam_size, vocab_size)`,
           updated cache values)
      """
      # Set decoder input to the last generated IDs
      decoder_input = ids[:, -1:]

      # Preprocess decoder input by getting embeddings and adding timing signal.
      # decoder_input = self.embedding_softmax_layer(decoder_input)
      source_decoder_input = decoder_input
      decoder_input = self.embedding_lookup(decoder_input)
      embedding_mask = tf.cast(
          tf.not_equal(source_decoder_input, 0), decoder_input.dtype)
      decoder_input *= tf.expand_dims(embedding_mask, -1)
      decoder_input += timing_signal[i]

      if self._padded_decode:
        # indexing does not work on TPU.
        bias_shape = decoder_self_attention_mask.shape.as_list()
        self_attention_mask = tf.slice(
            decoder_self_attention_mask, [0, i, 0],
            [bias_shape[0], 1, bias_shape[2]])
      else:
        self_attention_mask = decoder_self_attention_mask[:, i:i + 1, :i + 1]

      decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
      batch_size = decoder_shape[0]
      decoder_length = decoder_shape[1]

      self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
      attention_mask = cache.get("encoder_decoder_attention_mask")
      attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

      decoder_outputs = self.decoder_layer(
          decoder_input,
          cache.get("encoder_outputs"),
          self_attention_mask=self_attention_mask,
          cross_attention_mask=attention_mask,
          cache=cache,
          decode_loop_step=i if self._padded_decode else None)

      decoder_outputs = tf.cast(decoder_outputs, dtype=self.compute_dtype)
      logits = self._embedding_linear(self.embedding_lookup.embeddings,
                                      decoder_outputs)
      logits = tf.squeeze(logits, axis=[1])
      return logits, cache

    return symbols_to_logits_fn


class TransformerEncoder(tf.keras.layers.Layer):
  """Transformer encoder.

  Transformer encoder is made up of N identical layers. Each layer is composed
  of the sublayers:
    1. Self-attention layer
    2. Feedforward network (which is 2 fully-connected layers)
  """

  def __init__(self,
               num_layers=6,
               num_attention_heads=8,
               intermediate_size=2048,
               activation="relu",
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
    """Initialize a Transformer encoder.

    Args:
      num_layers: Number of layers.
      num_attention_heads: Number of attention heads.
      intermediate_size: Size of the intermediate (Feedforward) layer.
      activation: Activation for the intermediate layer.
      dropout_rate: Dropout probability.
      attention_dropout_rate: Dropout probability for attention layers.
      use_bias: Whether to enable use_bias in attention layer. If set False,
        use_bias in attention layer is disabled.
      norm_first: Whether to normalize inputs to attention and intermediate
        dense layers. If set False, output of attention and intermediate dense
        layers is normalized.
      norm_epsilon: Epsilon value to initialize normalization layers.
      intermediate_dropout: Dropout probability for intermediate_dropout_layer.
      **kwargs: keyword arguments passed to tf.keras.layers.Layer.
    """
    super(TransformerEncoder, self).__init__(**kwargs)
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._activation = activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout

  def build(self, input_shape):
    """Implements build() for the layer."""
    self.encoder_layers = []
    for i in range(self.num_layers):
      self.encoder_layers.append(
          keras_nlp.layers.TransformerEncoderBlock(
              num_attention_heads=self.num_attention_heads,
              inner_dim=self._intermediate_size,
              inner_activation=self._activation,
              output_dropout=self._dropout_rate,
              attention_dropout=self._attention_dropout_rate,
              use_bias=self._use_bias,
              norm_first=self._norm_first,
              norm_epsilon=self._norm_epsilon,
              inner_dropout=self._intermediate_dropout,
              attention_initializer=attention_initializer(input_shape[2]),
              name=("layer_%d" % i)))
    self.output_normalization = tf.keras.layers.LayerNormalization(
        epsilon=self._norm_epsilon, dtype="float32")
    super(TransformerEncoder, self).build(input_shape)

  def get_config(self):
    config = {
        "num_layers": self.num_layers,
        "num_attention_heads": self.num_attention_heads,
        "intermediate_size": self._intermediate_size,
        "activation": self._activation,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "intermediate_dropout": self._intermediate_dropout
    }
    base_config = super(TransformerEncoder, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, encoder_inputs, attention_mask=None):
    """Return the output of the encoder.

    Args:
      encoder_inputs: A tensor with shape `(batch_size, input_length,
        hidden_size)`.
      attention_mask: A mask for the encoder self-attention layer with shape
        `(batch_size, input_length, input_length)`.

    Returns:
      Output of encoder which is a `float32` tensor with shape
        `(batch_size, input_length, hidden_size)`.
    """
    for layer_idx in range(self.num_layers):
      encoder_inputs = self.encoder_layers[layer_idx](
          [encoder_inputs, attention_mask])

    output_tensor = encoder_inputs
    output_tensor = self.output_normalization(output_tensor)

    return output_tensor


class TransformerDecoder(tf.keras.layers.Layer):
  """Transformer decoder.

  Like the encoder, the decoder is made up of N identical layers.
  Each layer is composed of the sublayers:
    1. Self-attention layer
    2. Multi-headed attention layer combining encoder outputs with results from
       the previous self-attention layer.
    3. Feedforward network (2 fully-connected layers)
  """

  def __init__(self,
               num_layers=6,
               num_attention_heads=8,
               intermediate_size=2048,
               activation="relu",
               dropout_rate=0.0,
               attention_dropout_rate=0.0,
               use_bias=False,
               norm_first=True,
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
    """Initialize a Transformer decoder.

    Args:
      num_layers: Number of layers.
      num_attention_heads: Number of attention heads.
      intermediate_size: Size of the intermediate (Feedforward) layer.
      activation: Activation for the intermediate layer.
      dropout_rate: Dropout probability.
      attention_dropout_rate: Dropout probability for attention layers.
      use_bias: Whether to enable use_bias in attention layer. If set `False`,
        use_bias in attention layer is disabled.
      norm_first: Whether to normalize inputs to attention and intermediate
        dense layers. If set `False`, output of attention and intermediate
        dense layers is normalized.
      norm_epsilon: Epsilon value to initialize normalization layers.
      intermediate_dropout: Dropout probability for intermediate_dropout_layer.
      **kwargs: keyword arguments passed to tf.keras.layers.Layer.
    """
    super(TransformerDecoder, self).__init__(**kwargs)
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
    self._intermediate_size = intermediate_size
    self._activation = activation
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._intermediate_dropout = intermediate_dropout

  def build(self, input_shape):
    """Implements build() for the layer."""
    self.decoder_layers = []
    for i in range(self.num_layers):
      self.decoder_layers.append(
          layers.TransformerDecoderBlock(
              num_attention_heads=self.num_attention_heads,
              intermediate_size=self._intermediate_size,
              intermediate_activation=self._activation,
              dropout_rate=self._dropout_rate,
              attention_dropout_rate=self._attention_dropout_rate,
              use_bias=self._use_bias,
              norm_first=self._norm_first,
              norm_epsilon=self._norm_epsilon,
              intermediate_dropout=self._intermediate_dropout,
              attention_initializer=attention_initializer(input_shape[2]),
              name=("layer_%d" % i)))
    self.output_normalization = tf.keras.layers.LayerNormalization(
        epsilon=1e-6, dtype="float32")
    super(TransformerDecoder, self).build(input_shape)

  def get_config(self):
    config = {
        "num_layers": self.num_layers,
        "num_attention_heads": self.num_attention_heads,
        "intermediate_size": self._intermediate_size,
        "activation": self._activation,
        "dropout_rate": self._dropout_rate,
        "attention_dropout_rate": self._attention_dropout_rate,
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "intermediate_dropout": self._intermediate_dropout
    }
    base_config = super(TransformerDecoder, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self,
           target,
           memory,
           self_attention_mask=None,
           cross_attention_mask=None,
           cache=None,
           decode_loop_step=None):
    """Return the output of the decoder layer stacks.

    Args:
      target: A tensor with shape `(batch_size, target_length, hidden_size)`.
      memory: A tensor with shape `(batch_size, input_length, hidden_size)`.
      self_attention_mask: A tensor with shape `(batch_size, target_len,
        target_length)`, the mask for decoder self-attention layer.
      cross_attention_mask: A tensor with shape `(batch_size, target_length,
        input_length)` which is the mask for encoder-decoder attention layer.
      cache: (Used for fast decoding) A nested dictionary storing previous
        decoder self-attention values. The items are:
          {layer_n: {"k": A tensor with shape `(batch_size, i, key_channels)`,
                     "v": A tensor with shape `(batch_size, i, value_channels)`},
           ...}
      decode_loop_step: An integer, the step number of the decoding loop. Used
        only for autoregressive inference on TPU.

    Returns:
      Output of decoder, a `float32` tensor with shape
      `(batch_size, target_length, hidden_size)`.
    """
    output_tensor = target
    for layer_idx in range(self.num_layers):
      transformer_inputs = [
          output_tensor, memory, cross_attention_mask, self_attention_mask
      ]
      # Gets the cache for decoding.
      if cache is None:
        output_tensor, _ = self.decoder_layers[layer_idx](transformer_inputs)
      else:
        cache_layer_idx = str(layer_idx)
        output_tensor, cache[cache_layer_idx] = self.decoder_layers[layer_idx](
            transformer_inputs,
            cache=cache[cache_layer_idx],
            decode_loop_step=decode_loop_step)
    return self.output_normalization(output_tensor)


def attention_initializer(hidden_size):
  """Initializer for attention layers in Seq2SeqTransformer."""
  hidden_size = int(hidden_size)
  limit = math.sqrt(6.0 / (hidden_size + hidden_size))
  return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)
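Usage sketch (not part of the commit): Seq2SeqTransformer runs a training forward pass when `targets` is provided and falls back to beam-search decoding when it is not. The following minimal sketch mirrors `_build_model` in the unit test below; the specific sizes are illustrative.

# Hedged sketch: small encoder/decoder stacks wired into Seq2SeqTransformer,
# exercised on both the training path (with targets) and the inference path.
import numpy as np
import tensorflow as tf
from official.nlp.modeling.models import seq2seq_transformer

block_kwargs = dict(
    num_layers=1, num_attention_heads=2, intermediate_size=32,
    activation="relu", dropout_rate=0.01, attention_dropout_rate=0.01,
    use_bias=False, norm_first=True, norm_epsilon=1e-6,
    intermediate_dropout=0.01)
model = seq2seq_transformer.Seq2SeqTransformer(
    vocab_size=100, embedding_width=16, dropout_rate=0.01,
    padded_decode=False, decode_max_length=10, beam_size=4, alpha=0.6,
    encoder_layer=seq2seq_transformer.TransformerEncoder(**block_kwargs),
    decoder_layer=seq2seq_transformer.TransformerDecoder(**block_kwargs))

sources = np.ones((2, 10), dtype=np.int32)
targets = np.ones((2, 8), dtype=np.int32)
logits = model(dict(inputs=sources, targets=targets))  # training logits
decoded = model(dict(inputs=sources))                   # beam-search decode
print(logits.shape)              # (2, 8, 100)
print(decoded["outputs"].shape)  # decoded ids, one row per batch element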
official/nlp/modeling/models/seq2seq_transformer_test.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Transformer model."""

from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.modeling.models import seq2seq_transformer


class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):

  def _build_model(self, padded_decode, decode_max_length):
    num_layers = 1
    num_attention_heads = 2
    intermediate_size = 32
    vocab_size = 100
    embedding_width = 16
    encdec_kwargs = dict(
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        activation="relu",
        dropout_rate=0.01,
        attention_dropout_rate=0.01,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        intermediate_dropout=0.01)
    encoder_layer = seq2seq_transformer.TransformerEncoder(**encdec_kwargs)
    decoder_layer = seq2seq_transformer.TransformerDecoder(**encdec_kwargs)

    return seq2seq_transformer.Seq2SeqTransformer(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
        dropout_rate=0.01,
        padded_decode=padded_decode,
        decode_max_length=decode_max_length,
        beam_size=4,
        alpha=0.6,
        encoder_layer=encoder_layer,
        decoder_layer=decoder_layer)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
          ],
          mode="eager"))
  def test_create_model_with_ds(self, distribution):
    with distribution.scope():
      padded_decode = isinstance(
          distribution,
          (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
      decode_max_length = 10
      batch_size = 4
      model = self._build_model(padded_decode, decode_max_length)

      @tf.function
      def step(inputs):

        def _step_fn(inputs):
          return model(inputs)

        outputs = distribution.run(_step_fn, args=(inputs,))
        return tf.nest.map_structure(distribution.experimental_local_results,
                                     outputs)

      fake_inputs = dict(
          inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))
      local_outputs = step(fake_inputs)
      logging.info("local_outputs=%s", local_outputs)
      self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))

      fake_inputs = dict(
          inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32),
          targets=np.zeros((batch_size, 8), dtype=np.int32))
      local_outputs = step(fake_inputs)
      logging.info("local_outputs=%s", local_outputs)
      self.assertEqual(local_outputs[0].shape, (4, 8, 100))

  @parameterized.parameters(True, False)
  def test_create_savedmodel(self, padded_decode):
    decode_max_length = 10
    model = self._build_model(padded_decode, decode_max_length)

    class SaveModule(tf.Module):

      def __init__(self, model):
        super(SaveModule, self).__init__()
        self.model = model

      @tf.function
      def serve(self, inputs):
        return self.model.call(dict(inputs=inputs))

    save_module = SaveModule(model)
    if padded_decode:
      tensor_shape = (4, 10)
    else:
      tensor_shape = (None, None)
    signatures = dict(
        serving_default=save_module.serve.get_concrete_function(
            tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")))
    tf.saved_model.save(save_module, self.get_temp_dir(), signatures=signatures)


if __name__ == "__main__":
  tf.test.main()
official/nlp/modeling/models/xlnet.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""XLNet models."""
# pylint: disable=g-classes-have-attributes
from
typing
import
Any
,
Mapping
,
Union
import
tensorflow
as
tf
from
official.nlp.modeling
import
layers
from
official.nlp.modeling
import
networks
class
XLNetMaskedLM
(
tf
.
keras
.
layers
.
Layer
):
"""XLNet pretraining head."""
def
__init__
(
self
,
vocab_size
:
int
,
hidden_size
:
int
,
initializer
:
str
=
'glorot_uniform'
,
activation
:
str
=
'gelu'
,
name
=
None
,
**
kwargs
):
super
().
__init__
(
name
=
name
,
**
kwargs
)
self
.
_vocab_size
=
vocab_size
self
.
_hidden_size
=
hidden_size
self
.
_initializer
=
initializer
self
.
_activation
=
activation
def
build
(
self
,
input_shape
):
self
.
dense
=
tf
.
keras
.
layers
.
Dense
(
units
=
self
.
_hidden_size
,
activation
=
self
.
_activation
,
kernel_initializer
=
self
.
_initializer
,
name
=
'transform/dense'
)
self
.
layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
axis
=-
1
,
epsilon
=
1e-12
,
name
=
'transform/LayerNorm'
)
self
.
bias
=
self
.
add_weight
(
'output_bias/bias'
,
shape
=
(
self
.
_vocab_size
,),
initializer
=
'zeros'
,
trainable
=
True
)
super
().
build
(
input_shape
)
def
call
(
self
,
sequence_data
:
tf
.
Tensor
,
embedding_table
:
tf
.
Tensor
):
lm_data
=
self
.
dense
(
sequence_data
)
lm_data
=
self
.
layer_norm
(
lm_data
)
lm_data
=
tf
.
matmul
(
lm_data
,
embedding_table
,
transpose_b
=
True
)
logits
=
tf
.
nn
.
bias_add
(
lm_data
,
self
.
bias
)
return
logits
def
get_config
(
self
)
->
Mapping
[
str
,
Any
]:
config
=
{
'vocab_size'
:
self
.
_vocab_size
,
'hidden_size'
:
self
.
_hidden_size
,
'initializer'
:
self
.
_initializer
}
base_config
=
super
(
XLNetMaskedLM
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetPretrainer(tf.keras.Model):
  """XLNet-based pretrainer.

  This is an implementation of the network structure surrounding a
  Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

  Args:
    network: An XLNet/Transformer-XL based network. This network should output
      a sequence output and list of `state` tensors.
    mlm_activation: The activation (if any) to use in the Masked LM network. If
      None, then no activation will be used.
    mlm_initializer: The initializer (if any) to use in the masked LM. Defaults
      to a Glorot uniform initializer.
  """

  def __init__(self,
               network: Union[tf.keras.layers.Layer, tf.keras.Model],
               mlm_activation=None,
               mlm_initializer='glorot_uniform',
               name: str = None,
               **kwargs):
    super().__init__(name=name, **kwargs)
    self._config = {
        'network': network,
        'mlm_activation': mlm_activation,
        'mlm_initializer': mlm_initializer,
    }
    self._network = network
    self._hidden_size = network.get_config()['hidden_size']
    self._vocab_size = network.get_config()['vocab_size']
    self._activation = mlm_activation
    self._initializer = mlm_initializer
    self._masked_lm = XLNetMaskedLM(
        vocab_size=self._vocab_size,
        hidden_size=self._hidden_size,
        initializer=self._initializer)

  def call(self, inputs: Mapping[str, Any]):
    input_word_ids = inputs['input_word_ids']
    input_type_ids = inputs['input_type_ids']
    masked_tokens = inputs['masked_tokens']
    permutation_mask = inputs['permutation_mask']
    target_mapping = inputs['target_mapping']
    state = inputs.get('state', None)

    attention_output, state = self._network(
        input_ids=input_word_ids,
        segment_ids=input_type_ids,
        input_mask=None,
        state=state,
        permutation_mask=permutation_mask,
        target_mapping=target_mapping,
        masked_tokens=masked_tokens)

    embedding_table = self._network.get_embedding_lookup_table()
    mlm_outputs = self._masked_lm(
        sequence_data=attention_output,
        embedding_table=embedding_table)
    return mlm_outputs, state

  def get_config(self) -> Mapping[str, Any]:
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    return dict(encoder=self._network)
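A hedged sketch of how the pretrainer is driven (illustrative only; the encoder construction and every shape below are assumptions, since the Transformer-XL network and its arguments are not shown here):

# xlnet_encoder: an assumed XLNet/Transformer-XL network exposing
# get_embedding_lookup_table(), as required by the pretrainer above.
pretrainer = XLNetPretrainer(network=xlnet_encoder)
mlm_logits, state = pretrainer({
    'input_word_ids': input_word_ids,      # [batch, seq_len] int32 token ids
    'input_type_ids': input_type_ids,      # [batch, seq_len] int32 segment ids
    'masked_tokens': masked_tokens,        # [batch, seq_len] mask over predicted positions
    'permutation_mask': permutation_mask,  # [batch, seq_len, seq_len] factorization-order mask
    'target_mapping': target_mapping,      # [batch, num_predictions, seq_len] one-hot targets
})                                         # an optional 'state' key carries Transformer-XL memory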
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetClassifier(tf.keras.Model):
  """Classifier model based on XLNet.

  This is an implementation of the network structure surrounding a
  Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

  Note: This model does not utilize the memory mechanism used in the original
  XLNet Classifier.

  Args:
    network: An XLNet/Transformer-XL based network. This network should output
      a sequence output and list of `state` tensors.
    num_classes: Number of classes to predict from the classification network.
    initializer: The initializer (if any) to use in the classification
      networks. Defaults to a RandomNormal initializer.
    summary_type: Method used to summarize a sequence into a compact vector.
    dropout_rate: The dropout probability of the cls head.
  """

  def __init__(self,
               network: Union[tf.keras.layers.Layer, tf.keras.Model],
               num_classes: int,
               initializer: tf.keras.initializers.Initializer = 'random_normal',
               summary_type: str = 'last',
               dropout_rate: float = 0.1,
               **kwargs):
    super().__init__(**kwargs)
    self._network = network
    self._initializer = initializer
    self._summary_type = summary_type
    self._num_classes = num_classes
    self._config = {
        'network': network,
        'initializer': initializer,
        'num_classes': num_classes,
        'summary_type': summary_type,
        'dropout_rate': dropout_rate,
    }
    if summary_type == 'last':
      cls_token_idx = -1
    elif summary_type == 'first':
      cls_token_idx = 0
    else:
      raise ValueError('Invalid summary type provided: %s.' % summary_type)

    self.classifier = layers.ClassificationHead(
        inner_dim=network.get_config()['hidden_size'],
        num_classes=num_classes,
        initializer=initializer,
        dropout_rate=dropout_rate,
        cls_token_idx=cls_token_idx,
        name='sentence_prediction')

  def call(self, inputs: Mapping[str, Any]):
    input_ids = inputs['input_word_ids']
    segment_ids = inputs['input_type_ids']
    input_mask = tf.cast(inputs['input_mask'], tf.float32)
    state = inputs.get('mems', None)

    attention_output, _ = self._network(
        input_ids=input_ids,
        segment_ids=segment_ids,
        input_mask=input_mask,
        state=state)

    logits = self.classifier(attention_output)
    return logits

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    items = dict(encoder=self._network)
    if hasattr(self.classifier, 'checkpoint_items'):
      for key, item in self.classifier.checkpoint_items.items():
        items['.'.join([self.classifier.name, key])] = item
    return items
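As a hedged usage sketch (the encoder and all shapes below are assumptions), the classifier consumes a standard feature dictionary and returns per-example logits:

classifier = XLNetClassifier(network=xlnet_encoder, num_classes=2)
logits = classifier({
    'input_word_ids': input_word_ids,  # [batch, seq_len] token ids
    'input_type_ids': input_type_ids,  # [batch, seq_len] segment ids
    'input_mask': input_mask,          # [batch, seq_len], 1 for real tokens, 0 for padding
})                                     # -> [batch, num_classes]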
@tf.keras.utils.register_keras_serializable(package='Text')
class XLNetSpanLabeler(tf.keras.Model):
  """Span labeler model based on XLNet.

  This is an implementation of the network structure surrounding a
  Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

  Args:
    network: A transformer network. This network should output a sequence
      output and a classification output. Furthermore, it should expose its
      embedding table via a "get_embedding_table" method.
    start_n_top: Beam size for span start.
    end_n_top: Beam size for span end.
    dropout_rate: The dropout rate for the span labeling layer.
    span_labeling_activation: The activation for the span labeling head.
    initializer: The initializer (if any) to use in the span labeling network.
      Defaults to a Glorot uniform initializer.
  """

  def __init__(
      self,
      network: Union[tf.keras.layers.Layer, tf.keras.Model],
      start_n_top: int = 5,
      end_n_top: int = 5,
      dropout_rate: float = 0.1,
      span_labeling_activation: tf.keras.initializers.Initializer = 'tanh',
      initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
      **kwargs):
    super().__init__(**kwargs)
    self._config = {
        'network': network,
        'start_n_top': start_n_top,
        'end_n_top': end_n_top,
        'dropout_rate': dropout_rate,
        'span_labeling_activation': span_labeling_activation,
        'initializer': initializer,
    }
    network_config = network.get_config()
    try:
      input_width = network_config['inner_size']
      self._xlnet_base = True
    except KeyError:
      # BertEncoder uses 'intermediate_size' due to legacy naming.
      input_width = network_config['intermediate_size']
      self._xlnet_base = False

    self._network = network
    self._initializer = initializer
    self._start_n_top = start_n_top
    self._end_n_top = end_n_top
    self._dropout_rate = dropout_rate
    self._activation = span_labeling_activation
    self.span_labeling = networks.XLNetSpanLabeling(
        input_width=input_width,
        start_n_top=self._start_n_top,
        end_n_top=self._end_n_top,
        activation=self._activation,
        dropout_rate=self._dropout_rate,
        initializer=self._initializer)

  def call(self, inputs: Mapping[str, Any]):
    input_word_ids = inputs['input_word_ids']
    input_type_ids = inputs['input_type_ids']
    input_mask = inputs['input_mask']
    class_index = inputs['class_index']
    paragraph_mask = inputs['paragraph_mask']
    start_positions = inputs.get('start_positions', None)

    if self._xlnet_base:
      attention_output, _ = self._network(
          input_ids=input_word_ids,
          segment_ids=input_type_ids,
          input_mask=input_mask)
    else:
      network_output_dict = self._network(dict(
          input_word_ids=input_word_ids,
          input_type_ids=input_type_ids,
          input_mask=input_mask))
      attention_output = network_output_dict['sequence_output']

    outputs = self.span_labeling(
        sequence_data=attention_output,
        class_index=class_index,
        paragraph_mask=paragraph_mask,
        start_positions=start_positions)
    return outputs

  @property
  def checkpoint_items(self):
    return dict(encoder=self._network)

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
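Finally, a hedged sketch of the span labeler's expected inputs (key names are taken from `call` above; the encoder and shapes are assumptions, and `start_positions` is only supplied during training):

span_labeler = XLNetSpanLabeler(network=xlnet_encoder, start_n_top=5, end_n_top=5)
outputs = span_labeler({
    'input_word_ids': input_word_ids,    # [batch, seq_len]
    'input_type_ids': input_type_ids,    # [batch, seq_len]
    'input_mask': input_mask,            # [batch, seq_len]
    'class_index': class_index,          # [batch], position of the CLS token
    'paragraph_mask': paragraph_mask,    # [batch, seq_len], 1 where an answer may appear
    'start_positions': start_positions,  # [batch], ground-truth span starts (training only)
})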