Commit 356c98bd authored by Kaushik Shivakumar

Merge remote-tracking branch 'upstream/master' into detr-push-3

parents d31aba8a b9785623
@@ -12,12 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Trainer network for BERT-style models."""
+"""BERT Question Answering model."""
 # pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function

 import tensorflow as tf
...
@@ -36,7 +36,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
...
@@ -59,9 +59,8 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     """Validate compilation using explicit output names."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
-    sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
...
@@ -81,7 +80,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=2)
+        vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
...
@@ -101,7 +100,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+        vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
...
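For orientation, a minimal sketch (not part of the diff) of what dropping `sequence_length` buys: the encoder's Keras inputs now have a dynamic sequence dimension, so a single encoder instance can consume batches of different padded lengths. The import path follows the test above; the `[sequence_output, cls_output]` return structure is assumed from the encoder's usual outputs.

```python
import numpy as np
from official.nlp.modeling import networks

# Build the encoder exactly as the updated test does: no sequence_length.
encoder = networks.TransformerEncoder(vocab_size=100, num_layers=2)

for seq_len in (8, 16):  # two different lengths against the same encoder instance
  word_ids = np.random.randint(0, 100, size=(2, seq_len), dtype=np.int32)
  mask = np.ones((2, seq_len), dtype=np.int32)
  type_ids = np.zeros((2, seq_len), dtype=np.int32)
  sequence_output, cls_output = encoder([word_ids, mask, type_ids])
  print(sequence_output.shape)  # (2, seq_len, 768) with the default hidden size
```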
@@ -12,17 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Trainer network for BERT-style models."""
+"""BERT token classifier."""
 # pylint: disable=g-classes-have-attributes
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function

 import tensorflow as tf

-from official.nlp.modeling import networks

 @tf.keras.utils.register_keras_serializable(package='Text')
 class BertTokenClassifier(tf.keras.Model):
...
@@ -77,16 +71,23 @@ class BertTokenClassifier(tf.keras.Model):
     sequence_output = tf.keras.layers.Dropout(
         rate=dropout_rate)(sequence_output)
-    self.classifier = networks.TokenClassification(
-        input_width=sequence_output.shape[-1],
-        num_classes=num_classes,
-        initializer=initializer,
-        output=output,
-        name='classification')
-    predictions = self.classifier(sequence_output)
+    self.classifier = tf.keras.layers.Dense(
+        num_classes,
+        activation=None,
+        kernel_initializer=initializer,
+        name='predictions/transform/logits')
+    self.logits = self.classifier(sequence_output)
+    if output == 'logits':
+      output_tensors = self.logits
+    elif output == 'predictions':
+      output_tensors = tf.keras.layers.Activation(tf.nn.log_softmax)(
+          self.logits)
+    else:
+      raise ValueError(
+          ('Unknown `output` value "%s". `output` can be either "logits" or '
+           '"predictions"') % output)

     super(BertTokenClassifier, self).__init__(
-        inputs=inputs, outputs=predictions, **kwargs)
+        inputs=inputs, outputs=output_tensors, **kwargs)

   @property
   def checkpoint_items(self):
...
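A side note on the replacement head: it is just a per-token `Dense` projection, with a log-softmax applied only when `output='predictions'`. A standalone sketch of the equivalent computation, with made-up shapes:

```python
import tensorflow as tf

batch_size, seq_len, hidden_size, num_classes = 2, 8, 16, 5
sequence_output = tf.random.normal((batch_size, seq_len, hidden_size))

# output == 'logits': raw per-token class scores.
dense = tf.keras.layers.Dense(num_classes, name='predictions/transform/logits')
logits = dense(sequence_output)  # shape: (2, 8, 5)

# output == 'predictions': per-token log-probabilities over classes.
log_probs = tf.nn.log_softmax(logits, axis=-1)
tf.debugging.assert_near(tf.reduce_sum(tf.exp(log_probs), axis=-1),
                         tf.ones((batch_size, seq_len)))
```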
@@ -50,7 +50,6 @@ class ElectraPretrainer(tf.keras.Model):
     vocab_size: Size of generator output vocabulary
     num_classes: Number of classes to predict from the classification network
       for the generator network (not used now)
-    sequence_length: Input sequence length
     num_token_predictions: Number of tokens to predict from the masked LM.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
...
@@ -67,7 +66,6 @@ class ElectraPretrainer(tf.keras.Model):
                discriminator_network,
                vocab_size,
                num_classes,
-               sequence_length,
                num_token_predictions,
                mlm_activation=None,
                mlm_initializer='glorot_uniform',
...
@@ -80,7 +78,6 @@ class ElectraPretrainer(tf.keras.Model):
         'discriminator_network': discriminator_network,
         'vocab_size': vocab_size,
         'num_classes': num_classes,
-        'sequence_length': sequence_length,
         'num_token_predictions': num_token_predictions,
         'mlm_activation': mlm_activation,
         'mlm_initializer': mlm_initializer,
...
@@ -94,7 +91,6 @@ class ElectraPretrainer(tf.keras.Model):
     self.discriminator_network = discriminator_network
     self.vocab_size = vocab_size
     self.num_classes = num_classes
-    self.sequence_length = sequence_length
     self.num_token_predictions = num_token_predictions
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer
...
@@ -36,9 +36,13 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_generator_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length)
     test_discriminator_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length)

     # Create a ELECTRA trainer with the created network.
     num_classes = 3
...
@@ -48,7 +52,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         discriminator_network=test_discriminator_network,
         vocab_size=vocab_size,
         num_classes=num_classes,
-        sequence_length=sequence_length,
         num_token_predictions=num_token_predictions,
         disallow_correct=True)
...
@@ -20,8 +20,5 @@ into two smaller matrices and shares parameters across layers.
 intended for use as a classification or regression (if number of classes is set
 to 1) head.

-* [`TokenClassification`](token_classification.py) contains a single hidden
-layer, and is intended for use as a token classification head.
-
 * [`SpanLabeling`](span_labeling.py) implements a single-span labeler (that is, a prediction head that can predict one start and end index per batch item) based on a single dense hidden layer. It can be used in the SQuAD task.
@@ -17,5 +17,4 @@ from official.nlp.modeling.networks.albert_transformer_encoder import AlbertTran
 from official.nlp.modeling.networks.classification import Classification
 from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
 from official.nlp.modeling.networks.span_labeling import SpanLabeling
-from official.nlp.modeling.networks.token_classification import TokenClassification
 from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder
@@ -53,9 +53,6 @@ class AlbertTransformerEncoder(tf.keras.Model):
     num_layers: The number of transformer layers.
     num_attention_heads: The number of attention heads for each transformer. The
       hidden size must be divisible by the number of attention heads.
-    sequence_length: The sequence length that this encoder expects. If None, the
-      sequence length is dynamic; if an integer, the encoder will require
-      sequences padded to this length.
     max_sequence_length: The maximum sequence length that this encoder can
       consume. If None, max_sequence_length uses the value from sequence length.
       This determines the variable shape for positional embeddings.
...
@@ -74,8 +71,7 @@ class AlbertTransformerEncoder(tf.keras.Model):
                hidden_size=768,
                num_layers=12,
                num_attention_heads=12,
-               sequence_length=512,
-               max_sequence_length=None,
+               max_sequence_length=512,
                type_vocab_size=16,
                intermediate_size=3072,
                activation=activations.gelu,
...
@@ -86,8 +82,6 @@ class AlbertTransformerEncoder(tf.keras.Model):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)

-    if not max_sequence_length:
-      max_sequence_length = sequence_length
     self._self_setattr_tracking = False
     self._config_dict = {
         'vocab_size': vocab_size,
...
@@ -95,7 +89,6 @@ class AlbertTransformerEncoder(tf.keras.Model):
         'hidden_size': hidden_size,
         'num_layers': num_layers,
         'num_attention_heads': num_attention_heads,
-        'sequence_length': sequence_length,
         'max_sequence_length': max_sequence_length,
         'type_vocab_size': type_vocab_size,
         'intermediate_size': intermediate_size,
...
@@ -106,11 +99,11 @@ class AlbertTransformerEncoder(tf.keras.Model):
     }

     word_ids = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_word_ids')
+        shape=(None,), dtype=tf.int32, name='input_word_ids')
     mask = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_mask')
+        shape=(None,), dtype=tf.int32, name='input_mask')
     type_ids = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
+        shape=(None,), dtype=tf.int32, name='input_type_ids')

     if embedding_width is None:
       embedding_width = hidden_size
...
@@ -48,7 +48,6 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
     kwargs = dict(
         vocab_size=100,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3)
     if expected_dtype == tf.float16:
...
@@ -92,7 +91,6 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
         vocab_size=vocab_size,
         embedding_width=8,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3,
         type_vocab_size=num_types)
...
@@ -123,7 +121,6 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
         vocab_size=vocab_size,
         embedding_width=8,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         max_sequence_length=max_sequence_length,
         num_attention_heads=2,
         num_layers=3,
...
@@ -141,7 +138,6 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
         hidden_size=32,
         num_layers=3,
         num_attention_heads=2,
-        sequence_length=21,
         max_sequence_length=21,
         type_vocab_size=12,
         intermediate_size=1223,
...
@@ -129,16 +129,17 @@ class EncoderScaffold(tf.keras.Model):
       embeddings, attention_mask = self._embedding_network(inputs)
     else:
       self._embedding_network = None
+      seq_length = embedding_cfg.get('seq_length', None)
       word_ids = tf.keras.layers.Input(
-          shape=(embedding_cfg['seq_length'],),
+          shape=(seq_length,),
           dtype=tf.int32,
           name='input_word_ids')
       mask = tf.keras.layers.Input(
-          shape=(embedding_cfg['seq_length'],),
+          shape=(seq_length,),
           dtype=tf.int32,
           name='input_mask')
       type_ids = tf.keras.layers.Input(
-          shape=(embedding_cfg['seq_length'],),
+          shape=(seq_length,),
           dtype=tf.int32,
           name='input_type_ids')
       inputs = [word_ids, mask, type_ids]
...
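A small sketch of the new `EncoderScaffold` behavior (the config dict below is hypothetical and trimmed to the one key that matters): omitting `seq_length` now falls back to `None`, which gives the inputs a dynamic sequence dimension instead of raising a `KeyError`.

```python
import tensorflow as tf

embedding_cfg = {'vocab_size': 100, 'hidden_size': 32}  # no 'seq_length' key
seq_length = embedding_cfg.get('seq_length', None)      # None -> dynamic sequence dimension

word_ids = tf.keras.layers.Input(
    shape=(seq_length,), dtype=tf.int32, name='input_word_ids')
print(word_ids.shape)  # (None, None): batch size and sequence length both unspecified
```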
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classification network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Text')
class TokenClassification(tf.keras.Model):
"""TokenClassification network head for BERT modeling.
This network implements a simple token classifier head based on a dense layer.
*Note* that the network is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
input_width: The innermost dimension of the input tensor to this network.
num_classes: The number of classes that this network should classify to.
activation: The activation, if any, for the dense layer in this network.
initializer: The initializer for the dense layer in this network. Defaults
to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
"""
def __init__(self,
input_width,
num_classes,
initializer='glorot_uniform',
output='logits',
**kwargs):
self._self_setattr_tracking = False
self._config_dict = {
'input_width': input_width,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
}
sequence_data = tf.keras.layers.Input(
shape=(None, input_width), name='sequence_data', dtype=tf.float32)
self.logits = tf.keras.layers.Dense(
num_classes,
activation=None,
kernel_initializer=initializer,
name='predictions/transform/logits')(
sequence_data)
predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(self.logits)
if output == 'logits':
output_tensors = self.logits
elif output == 'predictions':
output_tensors = predictions
else:
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
super(TokenClassification, self).__init__(
inputs=[sequence_data], outputs=output_tensors, **kwargs)
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for token classification network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.networks import token_classification
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TokenClassificationTest(keras_parameterized.TestCase):
def test_network_creation(self):
"""Validate that the Keras object can be created."""
sequence_length = 5
input_width = 512
num_classes = 10
test_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes)
# Create a 3-dimensional input (the first dimension is implicit).
sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
dtype=tf.float32)
output = test_object(sequence_data)
# Validate that the outputs are of the expected shape.
expected_output_shape = [None, sequence_length, num_classes]
self.assertEqual(expected_output_shape, output.shape.as_list())
def test_network_invocation(self):
"""Validate that the Keras object can be invoked."""
sequence_length = 5
input_width = 512
num_classes = 10
test_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes, output='predictions')
# Create a 3-dimensional input (the first dimension is implicit).
sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
dtype=tf.float32)
output = test_object(sequence_data)
# Invoke the network as part of a Model.
model = tf.keras.Model(sequence_data, output)
input_data = 10 * np.random.random_sample((3, sequence_length, input_width))
_ = model.predict(input_data)
def test_network_invocation_with_internal_logits(self):
"""Validate that the logit outputs are correct."""
sequence_length = 5
input_width = 512
num_classes = 10
test_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes, output='predictions')
# Create a 3-dimensional input (the first dimension is implicit).
sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
dtype=tf.float32)
output = test_object(sequence_data)
model = tf.keras.Model(sequence_data, output)
logits_model = tf.keras.Model(test_object.inputs, test_object.logits)
batch_size = 3
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, input_width))
outputs = model.predict(input_data)
logits = logits_model.predict(input_data)
# Ensure that the tensor shapes are correct.
expected_output_shape = (batch_size, sequence_length, num_classes)
self.assertEqual(expected_output_shape, outputs.shape)
self.assertEqual(expected_output_shape, logits.shape)
# Ensure that the logits, when softmaxed, create the outputs.
input_tensor = tf.keras.Input(expected_output_shape[1:])
output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
softmax_model = tf.keras.Model(input_tensor, output_tensor)
calculated_softmax = softmax_model.predict(logits)
self.assertAllClose(outputs, calculated_softmax)
def test_network_invocation_with_internal_and_external_logits(self):
"""Validate that the logit outputs are correct."""
sequence_length = 5
input_width = 512
num_classes = 10
test_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes, output='logits')
# Create a 3-dimensional input (the first dimension is implicit).
sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
dtype=tf.float32)
output = test_object(sequence_data)
model = tf.keras.Model(sequence_data, output)
logits_model = tf.keras.Model(test_object.inputs, test_object.logits)
batch_size = 3
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, input_width))
outputs = model.predict(input_data)
logits = logits_model.predict(input_data)
# Ensure that the tensor shapes are correct.
expected_output_shape = (batch_size, sequence_length, num_classes)
self.assertEqual(expected_output_shape, outputs.shape)
self.assertEqual(expected_output_shape, logits.shape)
self.assertAllClose(outputs, logits)
def test_network_invocation_with_logit_output(self):
"""Validate that the logit outputs are correct."""
sequence_length = 5
input_width = 512
num_classes = 10
test_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes, output='predictions')
logit_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes, output='logits')
logit_object.set_weights(test_object.get_weights())
# Create a 3-dimensional input (the first dimension is implicit).
sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
dtype=tf.float32)
output = test_object(sequence_data)
logit_output = logit_object(sequence_data)
model = tf.keras.Model(sequence_data, output)
logits_model = tf.keras.Model(sequence_data, logit_output)
batch_size = 3
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, input_width))
outputs = model.predict(input_data)
logits = logits_model.predict(input_data)
# Ensure that the tensor shapes are correct.
expected_output_shape = (batch_size, sequence_length, num_classes)
self.assertEqual(expected_output_shape, outputs.shape)
self.assertEqual(expected_output_shape, logits.shape)
# Ensure that the logits, when softmaxed, create the outputs.
input_tensor = tf.keras.Input(expected_output_shape[1:])
output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
softmax_model = tf.keras.Model(input_tensor, output_tensor)
calculated_softmax = softmax_model.predict(logits)
self.assertAllClose(outputs, calculated_softmax)
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
network = token_classification.TokenClassification(
input_width=128,
num_classes=10,
initializer='zeros',
output='predictions')
# Create another network object from the first object's config.
new_network = token_classification.TokenClassification.from_config(
network.get_config())
# Validate that the config can be forced to JSON.
_ = new_network.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(network.get_config(), new_network.get_config())
def test_unknown_output_type_fails(self):
with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
_ = token_classification.TokenClassification(
input_width=128, num_classes=10, output='bad')
if __name__ == '__main__':
tf.test.main()
@@ -48,9 +48,8 @@ class TransformerEncoder(tf.keras.Model):
     num_layers: The number of transformer layers.
     num_attention_heads: The number of attention heads for each transformer. The
       hidden size must be divisible by the number of attention heads.
-    sequence_length: The sequence length that this encoder expects. If None, the
-      sequence length is dynamic; if an integer, the encoder will require
-      sequences padded to this length.
+    sequence_length: [Deprecated]. TODO(hongkuny): remove this argument once no
+      user is using it.
     max_sequence_length: The maximum sequence length that this encoder can
       consume. If None, max_sequence_length uses the value from sequence length.
       This determines the variable shape for positional embeddings.
...
@@ -83,8 +82,8 @@ class TransformerEncoder(tf.keras.Model):
                hidden_size=768,
                num_layers=12,
                num_attention_heads=12,
-               sequence_length=512,
-               max_sequence_length=None,
+               sequence_length=None,
+               max_sequence_length=512,
                type_vocab_size=16,
                intermediate_size=3072,
                activation=activations.gelu,
...
@@ -99,15 +98,12 @@ class TransformerEncoder(tf.keras.Model):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)

-    if not max_sequence_length:
-      max_sequence_length = sequence_length
     self._self_setattr_tracking = False
     self._config_dict = {
         'vocab_size': vocab_size,
         'hidden_size': hidden_size,
         'num_layers': num_layers,
         'num_attention_heads': num_attention_heads,
-        'sequence_length': sequence_length,
         'max_sequence_length': max_sequence_length,
         'type_vocab_size': type_vocab_size,
         'intermediate_size': intermediate_size,
...
@@ -121,11 +117,11 @@ class TransformerEncoder(tf.keras.Model):
     }

     word_ids = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_word_ids')
+        shape=(None,), dtype=tf.int32, name='input_word_ids')
     mask = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_mask')
+        shape=(None,), dtype=tf.int32, name='input_mask')
     type_ids = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
+        shape=(None,), dtype=tf.int32, name='input_type_ids')

     if embedding_width is None:
       embedding_width = hidden_size
...
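A sketch of what this constructor change means for serialization, assuming `get_config`/`from_config` simply round-trip the stored config dict as this class does: `sequence_length` disappears from the config, while `max_sequence_length` still sizes the positional-embedding table.

```python
from official.nlp.modeling.networks import transformer_encoder

encoder = transformer_encoder.TransformerEncoder(
    vocab_size=100, num_layers=2, max_sequence_length=64)

config = encoder.get_config()
assert 'sequence_length' not in config       # dropped from the serialized config
assert config['max_sequence_length'] == 64   # still controls positional embeddings
restored = transformer_encoder.TransformerEncoder.from_config(config)
```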
@@ -42,7 +42,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=100,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3)
     # Create the inputs (note that the first dimension is implicit).
...
@@ -71,7 +70,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=100,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3,
         return_all_encoder_outputs=True)
...
@@ -100,7 +98,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=100,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3)
     # Create the inputs (note that the first dimension is implicit).
...
@@ -132,7 +129,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3,
         type_vocab_size=num_types,
...
@@ -163,7 +159,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         max_sequence_length=max_sequence_length,
         num_attention_heads=2,
         num_layers=3,
...
@@ -177,7 +172,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         max_sequence_length=max_sequence_length,
         num_attention_heads=2,
         num_layers=3,
...
@@ -196,7 +190,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
         hidden_size=32,
         num_layers=3,
         num_attention_heads=2,
-        sequence_length=21,
         max_sequence_length=21,
         type_vocab_size=12,
         intermediate_size=1223,
...
@@ -413,7 +413,6 @@ def get_bert2bert_layers(params: configs.BERT2BERTConfig):
       activation=tf_utils.get_activation(bert_config.hidden_act),
       dropout_rate=bert_config.hidden_dropout_prob,
       attention_dropout_rate=bert_config.attention_probs_dropout_prob,
-      sequence_length=None,
       max_sequence_length=bert_config.max_position_embeddings,
       type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
...
@@ -130,13 +130,16 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
                weight_decay_rate=0.0,
                include_in_weight_decay=None,
                exclude_from_weight_decay=None,
+               gradient_clip_norm=1.0,
                name='AdamWeightDecay',
                **kwargs):
     super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2,
                                           epsilon, amsgrad, name, **kwargs)
     self.weight_decay_rate = weight_decay_rate
+    self.gradient_clip_norm = gradient_clip_norm
     self._include_in_weight_decay = include_in_weight_decay
     self._exclude_from_weight_decay = exclude_from_weight_decay
+    logging.info('gradient_clip_norm=%f', gradient_clip_norm)

   @classmethod
   def from_config(cls, config):
...
@@ -165,7 +168,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
                       name=None,
                       experimental_aggregate_gradients=True):
     grads, tvars = list(zip(*grads_and_vars))
-    if experimental_aggregate_gradients:
+    if experimental_aggregate_gradients and self.gradient_clip_norm > 0.0:
       # when experimental_aggregate_gradients = False, apply_gradients() no
       # longer implicitly allreduce gradients, users manually allreduce gradient
       # and passed the allreduced grads_and_vars. For now, the
...
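A brief usage sketch of the new optimizer knob. The module path `official.nlp.optimization` is an assumption, and the behavior follows the hunk above, where the implicit clipping is skipped when the norm is not positive:

```python
from official.nlp import optimization  # assumed location of AdamWeightDecay

# Default: gradients are globally clipped to norm 1.0 inside apply_gradients
# when experimental_aggregate_gradients is True.
opt_clipped = optimization.AdamWeightDecay(
    learning_rate=1e-4, weight_decay_rate=0.01)

# Setting gradient_clip_norm <= 0.0 disables that implicit clipping.
opt_unclipped = optimization.AdamWeightDecay(
    learning_rate=1e-4, weight_decay_rate=0.01, gradient_clip_norm=0.0)
```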
@@ -18,16 +18,21 @@ import dataclasses
 import tensorflow as tf

 from official.core import base_task
+from official.core import task_factory
+from official.modeling import tf_utils
 from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import bert
 from official.nlp.configs import electra
+from official.nlp.configs import encoders
 from official.nlp.data import pretrain_dataloader
+from official.nlp.modeling import layers
+from official.nlp.modeling import models


 @dataclasses.dataclass
-class ELECTRAPretrainConfig(cfg.TaskConfig):
+class ElectraPretrainConfig(cfg.TaskConfig):
   """The model config."""
-  model: electra.ELECTRAPretrainerConfig = electra.ELECTRAPretrainerConfig(
+  model: electra.ElectraPretrainerConfig = electra.ElectraPretrainerConfig(
       cls_heads=[
           bert.ClsHeadConfig(
               inner_dim=768,
...
@@ -39,13 +44,44 @@ class ELECTRAPretrainConfig(cfg.TaskConfig):
   validation_data: cfg.DataConfig = cfg.DataConfig()


-@base_task.register_task_cls(ELECTRAPretrainConfig)
-class ELECTRAPretrainTask(base_task.Task):
+def _build_pretrainer(
+    config: electra.ElectraPretrainerConfig) -> models.ElectraPretrainer:
+  """Instantiates ElectraPretrainer from the config."""
+  generator_encoder_cfg = config.generator_encoder
+  discriminator_encoder_cfg = config.discriminator_encoder
+  # Copy discriminator's embeddings to generator for easier model serialization.
+  discriminator_network = encoders.build_encoder(discriminator_encoder_cfg)
+  if config.tie_embeddings:
+    embedding_layer = discriminator_network.get_embedding_layer()
+    generator_network = encoders.build_encoder(
+        generator_encoder_cfg, embedding_layer=embedding_layer)
+  else:
+    generator_network = encoders.build_encoder(generator_encoder_cfg)
+
+  generator_encoder_cfg = generator_encoder_cfg.get()
+  return models.ElectraPretrainer(
+      generator_network=generator_network,
+      discriminator_network=discriminator_network,
+      vocab_size=generator_encoder_cfg.vocab_size,
+      num_classes=config.num_classes,
+      sequence_length=config.sequence_length,
+      num_token_predictions=config.num_masked_tokens,
+      mlm_activation=tf_utils.get_activation(
+          generator_encoder_cfg.hidden_activation),
+      mlm_initializer=tf.keras.initializers.TruncatedNormal(
+          stddev=generator_encoder_cfg.initializer_range),
+      classification_heads=[
+          layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
+      ],
+      disallow_correct=config.disallow_correct)
+
+
+@task_factory.register_task_cls(ElectraPretrainConfig)
+class ElectraPretrainTask(base_task.Task):
   """ELECTRA Pretrain Task (Masked LM + Replaced Token Detection)."""

   def build_model(self):
-    return electra.instantiate_pretrainer_from_cfg(
-        self.task_config.model)
+    return _build_pretrainer(self.task_config.model)

   def build_losses(self,
                    labels,
...
@@ -69,9 +105,7 @@ class ELECTRAPretrainTask(base_task.Task):
       sentence_outputs = tf.cast(
           model_outputs['sentence_outputs'], dtype=tf.float32)
-      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
-          sentence_labels,
-          sentence_outputs,
-          from_logits=True)
+      sentence_loss = tf.keras.losses.sparse_categorical_crossentropy(
+          sentence_labels, sentence_outputs, from_logits=True)
       metrics['next_sentence_loss'].update_state(sentence_loss)
       total_loss = mlm_loss + sentence_loss
     else:
...
@@ -24,15 +24,17 @@ from official.nlp.data import pretrain_dataloader
 from official.nlp.tasks import electra_task


-class ELECTRAPretrainTaskTest(tf.test.TestCase):
+class ElectraPretrainTaskTest(tf.test.TestCase):

   def test_task(self):
-    config = electra_task.ELECTRAPretrainConfig(
-        model=electra.ELECTRAPretrainerConfig(
-            generator_encoder=encoders.TransformerEncoderConfig(
-                vocab_size=30522, num_layers=1),
-            discriminator_encoder=encoders.TransformerEncoderConfig(
-                vocab_size=30522, num_layers=1),
+    config = electra_task.ElectraPretrainConfig(
+        model=electra.ElectraPretrainerConfig(
+            generator_encoder=encoders.EncoderConfig(
+                bert=encoders.BertEncoderConfig(vocab_size=30522,
+                                                num_layers=1)),
+            discriminator_encoder=encoders.EncoderConfig(
+                bert=encoders.BertEncoderConfig(vocab_size=30522,
+                                                num_layers=1)),
             num_masked_tokens=20,
             sequence_length=128,
             cls_heads=[
...
@@ -44,7 +46,7 @@ class ELECTRAPretrainTaskTest(tf.test.TestCase):
             max_predictions_per_seq=20,
             seq_length=128,
             global_batch_size=1))
-    task = electra_task.ELECTRAPretrainTask(config)
+    task = electra_task.ElectraPretrainTask(config)
     model = task.build_model()
     metrics = task.build_metrics()
     dataset = task.build_inputs(config.train_data)
...
@@ -14,21 +14,24 @@
 # limitations under the License.
 # ==============================================================================
 """Masked language task."""
-from absl import logging
 import dataclasses
 import tensorflow as tf

 from official.core import base_task
+from official.core import task_factory
+from official.modeling import tf_utils
 from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import bert
+from official.nlp.configs import encoders
 from official.nlp.data import data_loader_factory
+from official.nlp.modeling import layers
+from official.nlp.modeling import models


 @dataclasses.dataclass
 class MaskedLMConfig(cfg.TaskConfig):
   """The model config."""
-  init_checkpoint: str = ''
-  model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
+  model: bert.PretrainerConfig = bert.PretrainerConfig(cls_heads=[
       bert.ClsHeadConfig(
           inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
   ])
...
@@ -36,13 +39,23 @@ class MaskedLMConfig(cfg.TaskConfig):
   validation_data: cfg.DataConfig = cfg.DataConfig()


-@base_task.register_task_cls(MaskedLMConfig)
+@task_factory.register_task_cls(MaskedLMConfig)
 class MaskedLMTask(base_task.Task):
-  """Mock task object for testing."""
+  """Task object for Mask language modeling."""

   def build_model(self, params=None):
-    params = params or self.task_config.model
-    return bert.instantiate_pretrainer_from_cfg(params)
+    config = params or self.task_config.model
+    encoder_cfg = config.encoder
+    encoder_network = encoders.build_encoder(encoder_cfg)
+    cls_heads = [
+        layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
+    ] if config.cls_heads else []
+    return models.BertPretrainerV2(
+        mlm_activation=tf_utils.get_activation(config.mlm_activation),
+        mlm_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.mlm_initializer_range),
+        encoder_network=encoder_network,
+        classification_heads=cls_heads)

   def build_losses(self,
                    labels,
...
@@ -64,9 +77,8 @@ class MaskedLMTask(base_task.Task):
       sentence_outputs = tf.cast(
           model_outputs['next_sentence'], dtype=tf.float32)
       sentence_loss = tf.reduce_mean(
-          tf.keras.losses.sparse_categorical_crossentropy(sentence_labels,
-                                                          sentence_outputs,
-                                                          from_logits=True))
+          tf.keras.losses.sparse_categorical_crossentropy(
+              sentence_labels, sentence_outputs, from_logits=True))
       metrics['next_sentence_loss'].update_state(sentence_loss)
       total_loss = mlm_loss + sentence_loss
     else:
...
@@ -174,17 +186,3 @@ class MaskedLMTask(base_task.Task):
           aux_losses=model.losses)
       self.process_metrics(metrics, inputs, outputs)
       return {self.loss: loss}
-
-  def initialize(self, model: tf.keras.Model):
-    ckpt_dir_or_file = self.task_config.init_checkpoint
-    if tf.io.gfile.isdir(ckpt_dir_or_file):
-      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
-    if not ckpt_dir_or_file:
-      return
-
-    # Restoring all modules defined by the model, e.g. encoder, masked_lm and
-    # cls pooler. The best initialization may vary case by case.
-    ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-    status = ckpt.read(ckpt_dir_or_file)
-    status.expect_partial().assert_existing_objects_matched()
-    logging.info('Finished loading pretrained checkpoint from %s',
-                 ckpt_dir_or_file)
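To see the refactored `build_model` path end to end, here is a sketch that mirrors the test diff that follows. Field names come from this diff; the values are illustrative, and `PretrainerConfig` defaults are assumed to fill in the MLM activation and initializer range.

```python
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.tasks import masked_lm

config = masked_lm.MaskedLMConfig(
    model=bert.PretrainerConfig(
        encoder=encoders.EncoderConfig(
            bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))))
task = masked_lm.MaskedLMTask(config)
model = task.build_model()  # a models.BertPretrainerV2 built via encoders.build_encoder
```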
@@ -28,8 +28,10 @@ class MLMTaskTest(tf.test.TestCase):
   def test_task(self):
     config = masked_lm.MaskedLMConfig(
         init_checkpoint=self.get_temp_dir(),
-        model=bert.BertPretrainerConfig(
-            encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
+        model=bert.PretrainerConfig(
+            encoders.EncoderConfig(
+                bert=encoders.BertEncoderConfig(vocab_size=30522,
+                                                num_layers=1)),
             cls_heads=[
                 bert.ClsHeadConfig(
                     inner_dim=10, num_classes=2, name="next_sentence")
...