Commit f1b9d43f authored by Hongkun Yu, committed by A. Unique TensorFlower

Add dict outputs to BertEncoder and EncoderScaffold

PiperOrigin-RevId: 330817736
parent 54d2baa4
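
For orientation, the new flag switches the encoders from the legacy two-element output list to a dictionary keyed by output name. A minimal usage sketch, assuming the `official.nlp.modeling` package layout at this revision and BertEncoder's default hidden size of 768:

    import numpy as np
    from official.nlp.modeling import networks

    # dict_outputs=True replaces the legacy [sequence_output, pooled_output]
    # list with a dictionary (which also carries the per-layer encoder outputs).
    encoder = networks.BertEncoder(
        vocab_size=100, num_layers=2, dict_outputs=True)

    batch_size, seq_length = 2, 16
    data = dict(
        input_word_ids=np.random.randint(100, size=(batch_size, seq_length)),
        input_mask=np.ones((batch_size, seq_length), dtype=np.int32),
        input_type_ids=np.zeros((batch_size, seq_length), dtype=np.int32))

    outputs = encoder(data)
    print(outputs['sequence_output'].shape)  # (2, 16, 768)
    print(outputs['pooled_output'].shape)    # (2, 768)
    print(len(outputs['encoder_outputs']))   # 2, one entry per layer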
@@ -72,7 +72,11 @@ class BertClassifier(tf.keras.Model):
     if use_encoder_pooler:
       # Because we have a copy of inputs to create this Model object, we can
       # invoke the Network object with its own input tensors to start the Model.
-      _, cls_output = network(inputs)
+      outputs = network(inputs)
+      if isinstance(outputs, list):
+        cls_output = outputs[1]
+      else:
+        cls_output = outputs['pooled_output']
       cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
       self.classifier = networks.Classification(

@@ -83,7 +87,11 @@ class BertClassifier(tf.keras.Model):
           name='sentence_prediction')
       predictions = self.classifier(cls_output)
     else:
-      sequence_output, _ = network(inputs)
+      outputs = network(inputs)
+      if isinstance(outputs, list):
+        sequence_output = outputs[0]
+      else:
+        sequence_output = outputs['sequence_output']
       self.classifier = layers.ClassificationHead(
           inner_dim=sequence_output.shape[-1],
           num_classes=num_classes,
...
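Every downstream model in this change adopts the same two-branch access pattern so that it works with encoders built either way; distilled into a hypothetical helper for illustration (not part of the commit):

    def _unpack_encoder_outputs(outputs):
      # Accept both the legacy list form and the new dict form.
      if isinstance(outputs, list):
        return outputs[0], outputs[1]  # sequence_output, pooled_output
      return outputs['sequence_output'], outputs['pooled_output']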
@@ -31,13 +31,15 @@ from official.nlp.modeling.models import bert_classifier

 @keras_parameterized.run_all_keras_modes
 class BertClassifierTest(keras_parameterized.TestCase):

-  @parameterized.parameters(1, 3)
-  def test_bert_trainer(self, num_classes):
+  @parameterized.named_parameters(('single_cls', 1, False), ('3_cls', 3, False),
+                                  ('3_cls_dictoutputs', 3, True))
+  def test_bert_trainer(self, num_classes, dict_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.BertEncoder(vocab_size=vocab_size, num_layers=2)
+    test_network = networks.BertEncoder(
+        vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
...
@@ -183,7 +183,11 @@ class BertPretrainerV2(tf.keras.Model):
     self.encoder_network = encoder_network
     inputs = copy.copy(self.encoder_network.inputs)
-    sequence_output, _ = self.encoder_network(inputs)
+    outputs = self.encoder_network(inputs)
+    if isinstance(outputs, list):
+      sequence_output = outputs[0]
+    else:
+      sequence_output = outputs['sequence_output']
     self.classification_heads = classification_heads or []
     if len(set([cls.name for cls in self.classification_heads])) != len(
...
@@ -12,12 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for BERT trainer network."""
+"""Tests for BERT pretrainer model."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
+from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

@@ -111,7 +108,8 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     self.assertAllEqual(bert_trainer_model.get_config(),
                         new_bert_trainer_model.get_config())

-  def test_bert_pretrainerv2(self):
+  @parameterized.parameters(True, False)
+  def test_bert_pretrainerv2(self, dict_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100

@@ -119,7 +117,8 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     test_network = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=2,
-        max_sequence_length=sequence_length)
+        max_sequence_length=sequence_length,
+        dict_outputs=dict_outputs)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
...
@@ -64,7 +64,11 @@ class BertSpanLabeler(tf.keras.Model):
     # Because we have a copy of inputs to create this Model object, we can
     # invoke the Network object with its own input tensors to start the Model.
-    sequence_output, _ = network(inputs)
+    outputs = network(inputs)
+    if isinstance(outputs, list):
+      sequence_output = outputs[0]
+    else:
+      sequence_output = outputs['sequence_output']

     # This is an instance variable for ease of access to the underlying task
     # network.
...
@@ -14,10 +14,7 @@
 # ==============================================================================
 """Tests for BERT trainer network."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
+from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

@@ -30,12 +27,14 @@ from official.nlp.modeling.models import bert_span_labeler

 @keras_parameterized.run_all_keras_modes
 class BertSpanLabelerTest(keras_parameterized.TestCase):

-  def test_bert_trainer(self):
+  @parameterized.parameters(True, False)
+  def test_bert_trainer(self, dict_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
-    test_network = networks.BertEncoder(vocab_size=vocab_size, num_layers=2)
+    test_network = networks.BertEncoder(
+        vocab_size=vocab_size, num_layers=2, dict_outputs=dict_outputs)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
...
@@ -67,7 +67,11 @@ class BertTokenClassifier(tf.keras.Model):
     # Because we have a copy of inputs to create this Model object, we can
     # invoke the Network object with its own input tensors to start the Model.
-    sequence_output, _ = network(inputs)
+    outputs = network(inputs)
+    if isinstance(outputs, list):
+      sequence_output = outputs[0]
+    else:
+      sequence_output = outputs['sequence_output']
     sequence_output = tf.keras.layers.Dropout(rate=dropout_rate)(
         sequence_output)
...
@@ -12,12 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for BERT trainer network."""
+"""Tests for BERT token classifier."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
+from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import

@@ -30,7 +27,8 @@ from official.nlp.modeling.models import bert_token_classifier

 @keras_parameterized.run_all_keras_modes
 class BertTokenClassifierTest(keras_parameterized.TestCase):

-  def test_bert_trainer(self):
+  @parameterized.parameters(True, False)
+  def test_bert_trainer(self, dict_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100

@@ -38,7 +36,8 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     test_network = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=2,
-        max_sequence_length=sequence_length)
+        max_sequence_length=sequence_length,
+        dict_outputs=dict_outputs)
     # Create a BERT trainer with the created network.
     num_classes = 3
...
@@ -14,10 +14,6 @@
 # ==============================================================================
 """ALBERT (https://arxiv.org/abs/1810.04805) text encoder network."""
 # pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function

 import tensorflow as tf

@@ -64,6 +60,7 @@ class AlbertTransformerEncoder(tf.keras.Model):
     attention_dropout_rate: The dropout rate to use for the attention layers
       within the transformer layers.
     initializer: The initialzer to use for all weights in this encoder.
+    dict_outputs: Whether to use a dictionary as the model outputs.
   """

@@ -79,6 +76,7 @@ class AlbertTransformerEncoder(tf.keras.Model):
                dropout_rate=0.1,
                attention_dropout_rate=0.1,
                initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
+               dict_outputs=False,
                **kwargs):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)

@@ -172,9 +170,16 @@ class AlbertTransformerEncoder(tf.keras.Model):
         kernel_initializer=initializer,
         name='pooler_transform')(
             first_token_tensor)
+    if dict_outputs:
+      outputs = dict(
+          sequence_output=data,
+          pooled_output=cls_output,
+      )
+    else:
+      outputs = [data, cls_output]

     super(AlbertTransformerEncoder, self).__init__(
-        inputs=[word_ids, mask, type_ids], outputs=[data, cls_output], **kwargs)
+        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)

   def get_embedding_table(self):
     return self._embedding_layer.embeddings
...
@@ -109,7 +109,7 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
     mask_data = np.random.randint(2, size=(batch_size, sequence_length))
     type_id_data = np.random.randint(
         num_types, size=(batch_size, sequence_length))
-    _ = model.predict([word_id_data, mask_data, type_id_data])
+    list_outputs = model.predict([word_id_data, mask_data, type_id_data])

     # Creates a TransformerEncoder with max_sequence_length != sequence_length
     max_sequence_length = 128

@@ -124,6 +124,27 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
     model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
     _ = model.predict([word_id_data, mask_data, type_id_data])

+    # Tests dictionary outputs.
+    test_network_dict = albert_transformer_encoder.AlbertTransformerEncoder(
+        vocab_size=vocab_size,
+        embedding_width=8,
+        hidden_size=hidden_size,
+        max_sequence_length=max_sequence_length,
+        num_attention_heads=2,
+        num_layers=3,
+        type_vocab_size=num_types,
+        dict_outputs=True)
+    _ = test_network_dict([word_ids, mask, type_ids])
+    test_network_dict.set_weights(test_network.get_weights())
+    list_outputs = test_network([word_id_data, mask_data, type_id_data])
+    dict_outputs = test_network_dict(
+        dict(
+            input_word_ids=word_id_data,
+            input_mask=mask_data,
+            input_type_ids=type_id_data))
+    self.assertAllEqual(list_outputs[0], dict_outputs["sequence_output"])
+    self.assertAllEqual(list_outputs[1], dict_outputs["pooled_output"])
+
   def test_serialize_deserialize(self):
     tf.keras.mixed_precision.experimental.set_policy("mixed_float16")
     # Create a network object that sets all of its config options.
...
@@ -71,6 +71,7 @@ class BertEncoder(tf.keras.Model):
       embedding layer. Otherwise, we will reuse the given embedding layer. This
       parameter is originally added for ELECTRA model which needs to tie the
       generator embeddings with the discriminator embeddings.
+    dict_outputs: Whether to use a dictionary as the model outputs.
   """

@@ -90,6 +91,7 @@ class BertEncoder(tf.keras.Model):
                output_range=None,
                embedding_width=None,
                embedding_layer=None,
+               dict_outputs=False,
                **kwargs):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)

@@ -110,6 +112,7 @@ class BertEncoder(tf.keras.Model):
         'return_all_encoder_outputs': return_all_encoder_outputs,
         'output_range': output_range,
         'embedding_width': embedding_width,
+        'dict_outputs': dict_outputs,
     }

     word_ids = tf.keras.layers.Input(

@@ -197,11 +200,16 @@ class BertEncoder(tf.keras.Model):
         name='pooler_transform')
     cls_output = self._pooler_layer(first_token_tensor)

-    if return_all_encoder_outputs:
+    if dict_outputs:
+      outputs = dict(
+          sequence_output=encoder_outputs[-1],
+          pooled_output=cls_output,
+          encoder_outputs=encoder_outputs,
+      )
+    elif return_all_encoder_outputs:
       outputs = [encoder_outputs, cls_output]
     else:
       outputs = [encoder_outputs[-1], cls_output]

     super(BertEncoder, self).__init__(
         inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
...
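Since 'dict_outputs' is recorded in the network's config (the @@ -110,6 +112,7 @@ hunk above), the flag should survive serialization. A short sketch continuing the example near the top, under the assumption that BertEncoder follows the standard Keras get_config/from_config contract:

    config = encoder.get_config()
    assert config['dict_outputs'] is True
    # Rebuilding from the config restores dictionary outputs as well.
    restored = networks.BertEncoder.from_config(config)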
@@ -12,11 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for transformer-based text encoder network."""
+"""Tests for transformer-based bert encoder network."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 # Import libraries
 from absl.testing import parameterized

@@ -64,6 +60,35 @@ class BertEncoderTest(keras_parameterized.TestCase):
     self.assertAllEqual(tf.float32, data.dtype)
     self.assertAllEqual(tf.float32, pooled.dtype)

+    test_network_dict = bert_encoder.BertEncoder(
+        vocab_size=100,
+        hidden_size=hidden_size,
+        num_attention_heads=2,
+        num_layers=3,
+        dict_outputs=True)
+    # Create the inputs (note that the first dimension is implicit).
+    inputs = dict(
+        input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids)
+    _ = test_network_dict(inputs)
+    test_network_dict.set_weights(test_network.get_weights())
+
+    batch_size = 2
+    vocab_size = 100
+    num_types = 2
+    word_id_data = np.random.randint(
+        vocab_size, size=(batch_size, sequence_length))
+    mask_data = np.random.randint(2, size=(batch_size, sequence_length))
+    type_id_data = np.random.randint(
+        num_types, size=(batch_size, sequence_length))
+    list_outputs = test_network([word_id_data, mask_data, type_id_data])
+    dict_outputs = test_network_dict(
+        dict(
+            input_word_ids=word_id_data,
+            input_mask=mask_data,
+            input_type_ids=type_id_data))
+    self.assertAllEqual(list_outputs[0], dict_outputs["sequence_output"])
+    self.assertAllEqual(list_outputs[1], dict_outputs["pooled_output"])
+
   def test_all_encoder_outputs_network_creation(self):
     hidden_size = 32
     sequence_length = 21

@@ -199,7 +224,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
         initializer="glorot_uniform",
         return_all_encoder_outputs=False,
         output_range=-1,
-        embedding_width=16)
+        embedding_width=16,
+        dict_outputs=True)
     network = bert_encoder.BertEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["activation"] = tf.keras.activations.serialize(
...
@@ -93,6 +93,7 @@ class EncoderScaffold(tf.keras.Model):
       "kernel_initializer": The initializer for the transformer layers.
     return_all_layer_outputs: Whether to output sequence embedding outputs of
       all encoder transformer layers.
+    dict_outputs: Whether to use a dictionary as the model outputs.
   """

@@ -106,6 +107,7 @@ class EncoderScaffold(tf.keras.Model):
                hidden_cls=layers.Transformer,
                hidden_cfg=None,
                return_all_layer_outputs=False,
+               dict_outputs=False,
                **kwargs):
     self._self_setattr_tracking = False
     self._hidden_cls = hidden_cls

@@ -117,6 +119,7 @@ class EncoderScaffold(tf.keras.Model):
     self._embedding_cfg = embedding_cfg
     self._embedding_data = embedding_data
     self._return_all_layer_outputs = return_all_layer_outputs
+    self._dict_outputs = dict_outputs
     self._kwargs = kwargs

     if embedding_cls:

@@ -200,7 +203,13 @@ class EncoderScaffold(tf.keras.Model):
         name='cls_transform')
     cls_output = self._pooler_layer(first_token_tensor)

-    if return_all_layer_outputs:
+    if dict_outputs:
+      outputs = dict(
+          sequence_output=layer_output_data[-1],
+          pooled_output=cls_output,
+          encoder_outputs=layer_output_data,
+      )
+    elif return_all_layer_outputs:
       outputs = [layer_output_data, cls_output]
     else:
       outputs = [layer_output_data[-1], cls_output]

@@ -219,6 +228,7 @@ class EncoderScaffold(tf.keras.Model):
         'embedding_cfg': self._embedding_cfg,
         'hidden_cfg': self._hidden_cfg,
         'return_all_layer_outputs': self._return_all_layer_outputs,
+        'dict_outputs': self._dict_outputs,
     }
     if inspect.isclass(self._hidden_cls):
       config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
...
@@ -12,11 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for transformer-based text encoder network."""
+"""Tests for EncoderScaffold network."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 from absl.testing import parameterized
 import numpy as np

@@ -218,16 +214,17 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
         pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=0.02),
         hidden_cfg=hidden_cfg,
-        embedding_cfg=embedding_cfg)
+        embedding_cfg=embedding_cfg,
+        dict_outputs=True)

     # Create the inputs (note that the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    data, pooled = test_network([word_ids, mask, type_ids])
+    outputs = test_network([word_ids, mask, type_ids])

     # Create a model based off of this network:
-    model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+    model = tf.keras.Model([word_ids, mask, type_ids], outputs)

     # Invoke the model. We can't validate the output data here (the model is too
     # complex) but this will catch structural runtime errors.

@@ -237,7 +234,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
     mask_data = np.random.randint(2, size=(batch_size, sequence_length))
     type_id_data = np.random.randint(
         num_types, size=(batch_size, sequence_length))
-    _ = model.predict([word_id_data, mask_data, type_id_data])
+    preds = model.predict([word_id_data, mask_data, type_id_data])
+    self.assertEqual(preds["pooled_output"].shape, (3, hidden_size))

     # Creates a EncoderScaffold with max_sequence_length != sequence_length
     num_types = 7

@@ -272,8 +270,8 @@ class EncoderScaffoldLayerClassTest(keras_parameterized.TestCase):
             stddev=0.02),
         hidden_cfg=hidden_cfg,
         embedding_cfg=embedding_cfg)
-    model = tf.keras.Model([word_ids, mask, type_ids], [data, pooled])
+    outputs = test_network([word_ids, mask, type_ids])
+    model = tf.keras.Model([word_ids, mask, type_ids], outputs)
     _ = model.predict([word_id_data, mask_data, type_id_data])

   def test_serialize_deserialize(self):
...