"vscode:/vscode.git/clone" did not exist on "864ed34f565ea5c066778c3c1aa708903ec22be4"
Commit e3a74e5b authored by Hongkun Yu, committed by A. Unique TensorFlower

Rename TransformerEncoder to BertEncoder.

PiperOrigin-RevId: 328829120
parent 28cbb02d
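
In user terms, the rename is mechanical; a hedged before/after sketch (the constructor arguments are illustrative, everything else is taken from the hunks below):

    # Before this commit:
    #   from official.nlp.modeling.networks import transformer_encoder
    #   encoder = transformer_encoder.TransformerEncoder(vocab_size=100, num_layers=2)

    # After this commit:
    from official.nlp.modeling.networks import bert_encoder
    encoder = bert_encoder.BertEncoder(vocab_size=100, num_layers=2)

    # Code that used networks.TransformerEncoder keeps working, because the
    # networks package now aliases TransformerEncoder = BertEncoder.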
@@ -118,7 +118,7 @@ def get_transformer_encoder(bert_config,
       is to return the entire sequence output.

   Returns:
-    A networks.TransformerEncoder object.
+    An encoder object.
   """
   del sequence_length
   if transformer_encoder_cls is not None:
@@ -171,7 +171,7 @@ def get_transformer_encoder(bert_config,
   else:
     assert isinstance(bert_config, configs.BertConfig)
     kwargs['output_range'] = output_range
-  return networks.TransformerEncoder(**kwargs)
+  return networks.BertEncoder(**kwargs)


 def pretrain_model(bert_config,
...
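A minimal sketch of calling the factory changed above, assuming the surrounding module exposes it as shown in the hunk header (the argument values and the `configs.BertConfig` construction are illustrative; `sequence_length` is accepted but unused, per the `del` above):

    # Hypothetical usage; only names visible in the hunks are taken as given.
    bert_config = configs.BertConfig(vocab_size=30522)
    encoder = get_transformer_encoder(bert_config, sequence_length=None)
    # After this commit, the returned object is a networks.BertEncoder.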
@@ -48,7 +48,7 @@ class BertModelsTest(tf.test.TestCase):
         initializer=None,
         use_next_sentence_label=True)
     self.assertIsInstance(model, tf.keras.Model)
-    self.assertIsInstance(encoder, networks.TransformerEncoder)
+    self.assertIsInstance(encoder, networks.BertEncoder)

     # model has one scalar output: loss value.
     self.assertEqual(model.output.shape.as_list(), [
...
@@ -57,7 +57,7 @@ def _create_bert_model(cfg):
   Returns:
     A TransformerEncoder network.
   """
-  bert_encoder = networks.TransformerEncoder(
+  bert_encoder = networks.BertEncoder(
       vocab_size=cfg.vocab_size,
       hidden_size=cfg.hidden_size,
       num_layers=cfg.num_hidden_layers,
...
@@ -130,7 +130,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
 ENCODER_CLS = {
-    "bert": networks.TransformerEncoder,
+    "bert": networks.BertEncoder,
     "mobilebert": networks.MobileBERTEncoder,
     "albert": networks.AlbertTransformerEncoder,
 }
...
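The ENCODER_CLS registry above lets config-driven code resolve an encoder class by name; a minimal lookup sketch (the constructor arguments are illustrative, not part of this commit):

    # After this commit, the "bert" entry yields the renamed class.
    encoder_cls = ENCODER_CLS["bert"]   # networks.BertEncoder
    encoder = encoder_cls(vocab_size=30522, num_layers=12)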
@@ -20,7 +20,7 @@ import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
 from official.nlp.modeling.layers import masked_lm
-from official.nlp.modeling.networks import transformer_encoder
+from official.nlp.modeling.networks import bert_encoder


 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
@@ -36,7 +36,7 @@ class MaskedLMTest(keras_parameterized.TestCase):
     # First, create a transformer stack that we can use to get the LM's
     # vocabulary weight.
     if xformer_stack is None:
-      xformer_stack = transformer_encoder.TransformerEncoder(
+      xformer_stack = bert_encoder.BertEncoder(
           vocab_size=vocab_size,
           num_layers=1,
           hidden_size=hidden_size,
@@ -69,7 +69,7 @@ class MaskedLMTest(keras_parameterized.TestCase):
     sequence_length = 32
     hidden_size = 64
     num_predictions = 21
-    xformer_stack = transformer_encoder.TransformerEncoder(
+    xformer_stack = bert_encoder.BertEncoder(
         vocab_size=vocab_size,
         num_layers=1,
         hidden_size=hidden_size,
...
@@ -39,7 +39,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
                          output="predictions"):
     # First, create a transformer stack that we can use to get the LM's
     # vocabulary weight.
-    xformer_stack = networks.TransformerEncoder(
+    xformer_stack = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=1,
         sequence_length=sequence_length,
...
@@ -37,8 +37,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2)
+    test_network = networks.BertEncoder(vocab_size=vocab_size, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
@@ -61,7 +60,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(vocab_size=100, num_layers=2)
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
@@ -81,7 +80,8 @@ class BertClassifierTest(keras_parameterized.TestCase):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(vocab_size=100, num_layers=2)
+    test_network = networks.BertEncoder(
+        vocab_size=100, num_layers=2, sequence_length=5)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
...
@@ -35,7 +35,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=2,
         max_sequence_length=sequence_length)
@@ -70,7 +70,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=100, num_layers=2, sequence_length=2)

     # Create a BERT trainer with the created network.
@@ -92,7 +92,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=100, num_layers=2, max_sequence_length=5)

     # Create a BERT trainer with the created network. (Note that all the args
@@ -116,7 +116,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=2,
         max_sequence_length=sequence_length)
@@ -142,7 +142,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=100, num_layers=2, sequence_length=5)

     # Create a BERT trainer with the created network. (Note that all the args
...
@@ -35,8 +35,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2)
+    test_network = networks.BertEncoder(vocab_size=vocab_size, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
@@ -59,8 +58,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     """Validate compilation using explicit output names."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
-    test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2)
+    test_network = networks.BertEncoder(vocab_size=vocab_size, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
@@ -79,7 +77,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(vocab_size=100, num_layers=2)
+    test_network = networks.BertEncoder(vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
@@ -98,7 +96,8 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(vocab_size=100, num_layers=2)
+    test_network = networks.BertEncoder(
+        vocab_size=100, num_layers=2, sequence_length=5)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
...
@@ -35,7 +35,7 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=2,
         max_sequence_length=sequence_length)
@@ -62,7 +62,7 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=100, num_layers=2, max_sequence_length=2)

     # Create a BERT trainer with the created network.
@@ -83,7 +83,7 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=100, num_layers=2, max_sequence_length=5)

     # Create a BERT trainer with the created network. (Note that all the args
...
@@ -38,8 +38,10 @@ class DualEncoderTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the dual encoder model.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, hidden_size=hidden_size,
+    test_network = networks.BertEncoder(
+        vocab_size=vocab_size,
+        num_layers=2,
+        hidden_size=hidden_size,
         sequence_length=sequence_length)

     # Create a dual encoder model with the created network.
@@ -75,7 +77,7 @@ class DualEncoderTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the dual encoder model. (Here,
     # we use a short sequence_length for convenience.)
     sequence_length = 2
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=100, num_layers=2, sequence_length=sequence_length)

     # Create a dual encoder model with the created network.
@@ -101,7 +103,7 @@ class DualEncoderTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the dual encoder model. (Here,
     # we use a short sequence_length for convenience.)
     sequence_length = 32
-    test_network = networks.TransformerEncoder(
+    test_network = networks.BertEncoder(
         vocab_size=100, num_layers=2, sequence_length=sequence_length)

     # Create a dual encoder model with the created network. (Note that all the
...
@@ -35,11 +35,11 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the ELECTRA trainer.
     vocab_size = 100
     sequence_length = 512
-    test_generator_network = networks.TransformerEncoder(
+    test_generator_network = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=2,
         max_sequence_length=sequence_length)
-    test_discriminator_network = networks.TransformerEncoder(
+    test_discriminator_network = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=2,
         max_sequence_length=sequence_length)
@@ -91,9 +91,9 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the ELECTRA trainer. (Here, we
     # use a short sequence_length for convenience.)
-    test_generator_network = networks.TransformerEncoder(
+    test_generator_network = networks.BertEncoder(
         vocab_size=100, num_layers=4, max_sequence_length=3)
-    test_discriminator_network = networks.TransformerEncoder(
+    test_discriminator_network = networks.BertEncoder(
         vocab_size=100, num_layers=4, max_sequence_length=3)

     # Create an ELECTRA trainer with the created network.
@@ -128,9 +128,9 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
     """Validate that the ELECTRA trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
-    test_generator_network = networks.TransformerEncoder(
+    test_generator_network = networks.BertEncoder(
         vocab_size=100, num_layers=4, max_sequence_length=3)
-    test_discriminator_network = networks.TransformerEncoder(
+    test_discriminator_network = networks.BertEncoder(
         vocab_size=100, num_layers=4, max_sequence_length=3)

     # Create an ELECTRA trainer with the created network. (Note that all the args
...
@@ -14,8 +14,10 @@
 # ==============================================================================
 """Networks package definition."""
 from official.nlp.modeling.networks.albert_transformer_encoder import AlbertTransformerEncoder
+from official.nlp.modeling.networks.bert_encoder import BertEncoder
 from official.nlp.modeling.networks.classification import Classification
 from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
 from official.nlp.modeling.networks.mobile_bert_encoder import MobileBERTEncoder
 from official.nlp.modeling.networks.span_labeling import SpanLabeling
-from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder
+# Backward compatibility. The modules are deprecated.
+TransformerEncoder = BertEncoder
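
The two lines added at the bottom keep the old name importable as an alias, so existing call sites do not break; a quick sanity check, assuming the package is imported as elsewhere in this diff:

    from official.nlp.modeling import networks

    # Both names refer to the same class object after this commit.
    assert networks.TransformerEncoder is networks.BertEncoder
    encoder = networks.TransformerEncoder(vocab_size=100, num_layers=2)  # still works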
@@ -14,10 +14,6 @@
 # ==============================================================================
 """Transformer-based text encoder network."""
 # pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function

 import tensorflow as tf
@@ -26,7 +22,7 @@ from official.nlp.modeling import layers

 @tf.keras.utils.register_keras_serializable(package='Text')
-class TransformerEncoder(tf.keras.Model):
+class BertEncoder(tf.keras.Model):
   """Bi-directional Transformer-based encoder network.

   This network implements a bi-directional Transformer-based encoder as
@@ -207,7 +203,7 @@ class TransformerEncoder(tf.keras.Model):
     else:
       outputs = [encoder_outputs[-1], cls_output]

-    super(TransformerEncoder, self).__init__(
+    super(BertEncoder, self).__init__(
         inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)

   def get_embedding_table(self):
...
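Because BertEncoder is a functional tf.keras.Model wired over [word_ids, mask, type_ids] (see the super().__init__ call above), it can be called directly on three aligned integer arrays; a small sketch with sizes borrowed from the tests below (a minimal sketch, assuming default keyword values for everything not shown):

    import numpy as np
    from official.nlp.modeling.networks import bert_encoder

    encoder = bert_encoder.BertEncoder(
        vocab_size=100, hidden_size=32, num_attention_heads=2, num_layers=2)
    batch_size, sequence_length = 2, 21
    word_ids = np.random.randint(100, size=(batch_size, sequence_length))
    mask = np.ones((batch_size, sequence_length), dtype=np.int32)
    type_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
    # With return_all_encoder_outputs left at its default, the model returns
    # the last layer's sequence output and the pooled CLS output.
    sequence_output, cls_output = encoder([word_ids, mask, type_ids])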
@@ -18,28 +18,29 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+# Import libraries
 from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf

 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
-from official.nlp.modeling.networks import transformer_encoder
+from official.nlp.modeling.networks import bert_encoder


 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
 # guarantees forward compatibility of this code for the V2 switchover.
 @keras_parameterized.run_all_keras_modes
-class TransformerEncoderTest(keras_parameterized.TestCase):
+class BertEncoderTest(keras_parameterized.TestCase):

   def tearDown(self):
-    super(TransformerEncoderTest, self).tearDown()
+    super(BertEncoderTest, self).tearDown()
     tf.keras.mixed_precision.experimental.set_policy("float32")

   def test_network_creation(self):
     hidden_size = 32
     sequence_length = 21
-    # Create a small TransformerEncoder for testing.
-    test_network = transformer_encoder.TransformerEncoder(
+    # Create a small BertEncoder for testing.
+    test_network = bert_encoder.BertEncoder(
         vocab_size=100,
         hidden_size=hidden_size,
         num_attention_heads=2,
@@ -66,8 +67,8 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
   def test_all_encoder_outputs_network_creation(self):
     hidden_size = 32
     sequence_length = 21
-    # Create a small TransformerEncoder for testing.
-    test_network = transformer_encoder.TransformerEncoder(
+    # Create a small BertEncoder for testing.
+    test_network = bert_encoder.BertEncoder(
         vocab_size=100,
         hidden_size=hidden_size,
         num_attention_heads=2,
@@ -94,8 +95,8 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     hidden_size = 32
     sequence_length = 21
     tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
-    # Create a small TransformerEncoder for testing.
-    test_network = transformer_encoder.TransformerEncoder(
+    # Create a small BertEncoder for testing.
+    test_network = bert_encoder.BertEncoder(
         vocab_size=100,
         hidden_size=hidden_size,
         num_attention_heads=2,
@@ -125,8 +126,8 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     sequence_length = 21
     vocab_size = 57
     num_types = 7
-    # Create a small TransformerEncoder for testing.
-    test_network = transformer_encoder.TransformerEncoder(
+    # Create a small BertEncoder for testing.
+    test_network = bert_encoder.BertEncoder(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
         num_attention_heads=2,
@@ -153,9 +154,9 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
         num_types, size=(batch_size, sequence_length))
     _ = model.predict([word_id_data, mask_data, type_id_data])

-    # Creates a TransformerEncoder with max_sequence_length != sequence_length
+    # Creates a BertEncoder with max_sequence_length != sequence_length
     max_sequence_length = 128
-    test_network = transformer_encoder.TransformerEncoder(
+    test_network = bert_encoder.BertEncoder(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
         max_sequence_length=max_sequence_length,
@@ -167,8 +168,8 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     outputs = model.predict([word_id_data, mask_data, type_id_data])
     self.assertEqual(outputs[0].shape[1], out_seq_len)

-    # Creates a TransformerEncoder with embedding_width != hidden_size
-    test_network = transformer_encoder.TransformerEncoder(
+    # Creates a BertEncoder with embedding_width != hidden_size
+    test_network = bert_encoder.BertEncoder(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
         max_sequence_length=max_sequence_length,
@@ -199,7 +200,7 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
         return_all_encoder_outputs=False,
         output_range=-1,
         embedding_width=16)
-    network = transformer_encoder.TransformerEncoder(**kwargs)
+    network = bert_encoder.BertEncoder(**kwargs)

     expected_config = dict(kwargs)
     expected_config["activation"] = tf.keras.activations.serialize(
@@ -209,8 +210,7 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     self.assertEqual(network.get_config(), expected_config)

     # Create another network object from the first object's config.
-    new_network = transformer_encoder.TransformerEncoder.from_config(
-        network.get_config())
+    new_network = bert_encoder.BertEncoder.from_config(network.get_config())

     # Validate that the config can be forced to JSON.
     _ = new_network.to_json()
...
@@ -404,7 +404,7 @@ def get_bert2bert_layers(params: configs.BERT2BERTConfig):
   target_ids = tf.keras.layers.Input(
       shape=(None,), name="target_ids", dtype=tf.int32)
   bert_config = utils.get_bert_config_from_params(params)
-  bert_model_layer = networks.TransformerEncoder(
+  bert_model_layer = networks.BertEncoder(
       vocab_size=bert_config.vocab_size,
       hidden_size=bert_config.hidden_size,
       num_layers=bert_config.num_hidden_layers,
@@ -454,7 +454,7 @@ def get_nhnet_layers(params: configs.NHNetConfig):
   segment_ids = tf.keras.layers.Input(
       shape=(None,), name="segment_ids", dtype=tf.int32)
   bert_config = utils.get_bert_config_from_params(params)
-  bert_model_layer = networks.TransformerEncoder(
+  bert_model_layer = networks.BertEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
...