Commit a62c2bfc authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 324137297
parent 76f760c4
@@ -86,7 +86,7 @@ def _create_albert_model(cfg):
       activation=activations.gelu,
       dropout_rate=cfg.hidden_dropout_prob,
       attention_dropout_rate=cfg.attention_probs_dropout_prob,
-      sequence_length=cfg.max_position_embeddings,
+      max_sequence_length=cfg.max_position_embeddings,
       type_vocab_size=cfg.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
           stddev=cfg.initializer_range))
......
@@ -104,14 +104,14 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
 @gin.configurable
 def get_transformer_encoder(bert_config,
-                            sequence_length,
+                            sequence_length=None,
                             transformer_encoder_cls=None,
                             output_range=None):
   """Gets a 'TransformerEncoder' object.

   Args:
     bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
-    sequence_length: Maximum sequence length of the training data.
+    sequence_length: [Deprecated].
     transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
       default BERT encoder implementation.
     output_range: the sequence output range, [0, output_range). Default setting
@@ -120,13 +120,13 @@ def get_transformer_encoder(bert_config,
   Returns:
     A networks.TransformerEncoder object.
   """
+  del sequence_length
   if transformer_encoder_cls is not None:
     # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
     embedding_cfg = dict(
         vocab_size=bert_config.vocab_size,
         type_vocab_size=bert_config.type_vocab_size,
         hidden_size=bert_config.hidden_size,
-        seq_length=sequence_length,
         max_seq_length=bert_config.max_position_embeddings,
         initializer=tf.keras.initializers.TruncatedNormal(
             stddev=bert_config.initializer_range),
@@ -161,7 +161,6 @@ def get_transformer_encoder(bert_config,
       activation=tf_utils.get_activation(bert_config.hidden_act),
       dropout_rate=bert_config.hidden_dropout_prob,
       attention_dropout_rate=bert_config.attention_probs_dropout_prob,
-      sequence_length=sequence_length,
       max_sequence_length=bert_config.max_position_embeddings,
       type_vocab_size=bert_config.type_vocab_size,
       embedding_width=bert_config.embedding_size,
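After this change, callers no longer need to pass sequence_length to get_transformer_encoder: the argument is accepted only for backward compatibility and ignored, and the positional-embedding size comes from the config's max_position_embeddings. A minimal usage sketch, assuming the usual Model Garden layout (official.nlp.bert.configs and official.nlp.bert.bert_models); the config values are illustrative, not part of this commit:

import tensorflow as tf
from official.nlp.bert import bert_models, configs

# Illustrative config; only max_position_embeddings matters for embedding size.
bert_config = configs.BertConfig(
    vocab_size=30522,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    max_position_embeddings=512)

# sequence_length is now optional (and ignored); the returned encoder accepts
# inputs of any padded length up to max_position_embeddings.
encoder = bert_models.get_transformer_encoder(bert_config)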
......
@@ -56,8 +56,6 @@ class BertModelsTest(tf.test.TestCase):
     # Expect two output from encoder: sequence and classification output.
     self.assertIsInstance(encoder.output, list)
     self.assertLen(encoder.output, 2)
-    # shape should be [batch size, seq_length, hidden_size]
-    self.assertEqual(encoder.output[0].shape.as_list(), [None, 5, 16])
     # shape should be [batch size, hidden_size]
     self.assertEqual(encoder.output[1].shape.as_list(), [None, 16])
@@ -74,16 +72,12 @@ class BertModelsTest(tf.test.TestCase):
     # Expect two output from model: start positions and end positions
     self.assertIsInstance(model.output, list)
     self.assertLen(model.output, 2)
-    # shape should be [batch size, seq_length]
-    self.assertEqual(model.output[0].shape.as_list(), [None, 5])
-    # shape should be [batch size, seq_length]
-    self.assertEqual(model.output[1].shape.as_list(), [None, 5])

     # Expect two output from core_model: sequence and classification output.
     self.assertIsInstance(core_model.output, list)
     self.assertLen(core_model.output, 2)
-    # shape should be [batch size, seq_length, hidden_size]
-    self.assertEqual(core_model.output[0].shape.as_list(), [None, 5, 16])
+    # shape should be [batch size, None, hidden_size]
+    self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
     # shape should be [batch size, hidden_size]
     self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
@@ -104,8 +98,8 @@ class BertModelsTest(tf.test.TestCase):
     # Expect two output from core_model: sequence and classification output.
     self.assertIsInstance(core_model.output, list)
     self.assertLen(core_model.output, 2)
-    # shape should be [batch size, 1, hidden_size]
-    self.assertEqual(core_model.output[0].shape.as_list(), [None, 1, 16])
+    # shape should be [batch size, None, hidden_size]
+    self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
     # shape should be [batch size, hidden_size]
     self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
......
@@ -61,7 +61,7 @@ def _create_bert_model(cfg):
       activation=activations.gelu,
       dropout_rate=cfg.hidden_dropout_prob,
       attention_dropout_rate=cfg.attention_probs_dropout_prob,
-      sequence_length=cfg.max_position_embeddings,
+      max_sequence_length=cfg.max_position_embeddings,
       type_vocab_size=cfg.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
          stddev=cfg.initializer_range),
......
@@ -54,7 +54,6 @@ def instantiate_encoder_from_cfg(
       vocab_size=config.vocab_size,
       type_vocab_size=config.type_vocab_size,
       hidden_size=config.hidden_size,
-      seq_length=None,
       max_seq_length=config.max_position_embeddings,
       initializer=tf.keras.initializers.TruncatedNormal(
           stddev=config.initializer_range),
@@ -90,7 +89,6 @@ def instantiate_encoder_from_cfg(
       activation=tf_utils.get_activation(config.hidden_activation),
       dropout_rate=config.dropout_rate,
       attention_dropout_rate=config.attention_dropout_rate,
-      sequence_length=None,
       max_sequence_length=config.max_position_embeddings,
       type_vocab_size=config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
......
@@ -34,7 +34,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
   def create_layer(self,
                    vocab_size,
-                   sequence_length,
                    hidden_size,
                    output='predictions',
                    xformer_stack=None):
@@ -44,7 +43,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
       xformer_stack = transformer_encoder.TransformerEncoder(
           vocab_size=vocab_size,
           num_layers=1,
-          sequence_length=sequence_length,
           hidden_size=hidden_size,
           num_attention_heads=4,
       )
@@ -62,7 +60,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
     num_predictions = 21
     test_layer = self.create_layer(
         vocab_size=vocab_size,
-        sequence_length=sequence_length,
         hidden_size=hidden_size)

     # Make sure that the output tensor of the masked LM is the right shape.
@@ -81,19 +78,16 @@ class MaskedLMTest(keras_parameterized.TestCase):
     xformer_stack = transformer_encoder.TransformerEncoder(
         vocab_size=vocab_size,
         num_layers=1,
-        sequence_length=sequence_length,
         hidden_size=hidden_size,
         num_attention_heads=4,
     )
     test_layer = self.create_layer(
         vocab_size=vocab_size,
-        sequence_length=sequence_length,
         hidden_size=hidden_size,
         xformer_stack=xformer_stack,
         output='predictions')
     logit_layer = self.create_layer(
         vocab_size=vocab_size,
-        sequence_length=sequence_length,
         hidden_size=hidden_size,
         xformer_stack=xformer_stack,
         output='logits')
@@ -134,7 +128,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
     num_predictions = 21
     test_layer = self.create_layer(
         vocab_size=vocab_size,
-        sequence_length=sequence_length,
         hidden_size=hidden_size)

     # Create a model from the masked LM layer.
@@ -155,7 +148,7 @@ class MaskedLMTest(keras_parameterized.TestCase):
   def test_unknown_output_type_fails(self):
     with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
       _ = self.create_layer(
-          vocab_size=8, sequence_length=8, hidden_size=8, output='bad')
+          vocab_size=8, hidden_size=8, output='bad')

 if __name__ == '__main__':
......
@@ -38,7 +38,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
@@ -62,7 +62,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=2)
+        vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
@@ -83,7 +83,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+        vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
......
@@ -94,7 +94,8 @@ class BertPretrainer(tf.keras.Model):
     if isinstance(cls_output, list):
       cls_output = cls_output[-1]
     sequence_output_length = sequence_output.shape.as_list()[1]
-    if sequence_output_length < num_token_predictions:
+    if sequence_output_length is not None and (sequence_output_length <
+                                               num_token_predictions):
       raise ValueError(
           "The passed network's output length is %s, which is less than the "
           'requested num_token_predictions %s.' %
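With dynamic sequence lengths the static shape along the sequence axis can be None, so the guard above only raises when a concrete length is known to be too short. A small, self-contained sketch of the same pattern (the helper name and sizes are illustrative, not from the commit):

import tensorflow as tf

def check_output_length(sequence_output: tf.Tensor, num_token_predictions: int):
  """Mirrors the guard above: only complain when the static length is known."""
  sequence_output_length = sequence_output.shape.as_list()[1]
  if sequence_output_length is not None and (sequence_output_length <
                                             num_token_predictions):
    raise ValueError(
        "The passed network's output length is %s, which is less than the "
        'requested num_token_predictions %s.' %
        (sequence_output_length, num_token_predictions))

# A dynamic-length tensor passes the check even though its length is unknown.
check_output_length(tf.keras.Input(shape=(None, 16)), num_token_predictions=20)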
......
@@ -36,7 +36,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
@@ -59,9 +59,8 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     """Validate compilation using explicit output names."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
-    sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
@@ -81,7 +80,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=2)
+        vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_span_labeler.BertSpanLabeler(test_network)
@@ -101,7 +100,7 @@ class BertSpanLabelerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+        vocab_size=100, num_layers=2)

     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
......
@@ -50,7 +50,6 @@ class ElectraPretrainer(tf.keras.Model):
     vocab_size: Size of generator output vocabulary
     num_classes: Number of classes to predict from the classification network
       for the generator network (not used now)
-    sequence_length: Input sequence length
     num_token_predictions: Number of tokens to predict from the masked LM.
     mlm_activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
@@ -67,7 +66,6 @@ class ElectraPretrainer(tf.keras.Model):
                discriminator_network,
                vocab_size,
                num_classes,
-               sequence_length,
                num_token_predictions,
                mlm_activation=None,
                mlm_initializer='glorot_uniform',
@@ -80,7 +78,6 @@ class ElectraPretrainer(tf.keras.Model):
         'discriminator_network': discriminator_network,
         'vocab_size': vocab_size,
         'num_classes': num_classes,
-        'sequence_length': sequence_length,
         'num_token_predictions': num_token_predictions,
         'mlm_activation': mlm_activation,
         'mlm_initializer': mlm_initializer,
@@ -94,7 +91,6 @@ class ElectraPretrainer(tf.keras.Model):
     self.discriminator_network = discriminator_network
     self.vocab_size = vocab_size
     self.num_classes = num_classes
-    self.sequence_length = sequence_length
     self.num_token_predictions = num_token_predictions
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer
......
@@ -36,9 +36,13 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_generator_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length)
     test_discriminator_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length)

     # Create a ELECTRA trainer with the created network.
     num_classes = 3
@@ -48,7 +52,6 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
         discriminator_network=test_discriminator_network,
         vocab_size=vocab_size,
         num_classes=num_classes,
-        sequence_length=sequence_length,
         num_token_predictions=num_token_predictions,
         disallow_correct=True)
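After this change, ElectraPretrainer no longer takes a sequence_length argument; the generator and discriminator encoders handle dynamic lengths themselves, and max_sequence_length on the encoders only sizes the positional embeddings. A hedged construction sketch: the import path (official.nlp.modeling.models) and the generator_network keyword are assumptions about the surrounding code, and all sizes are illustrative.

from official.nlp.modeling import models, networks

generator = networks.TransformerEncoder(vocab_size=100, num_layers=2)
discriminator = networks.TransformerEncoder(vocab_size=100, num_layers=2)

# sequence_length is gone from the constructor; only the class count and the
# number of masked-LM predictions are still required here.
pretrainer = models.ElectraPretrainer(
    generator_network=generator,
    discriminator_network=discriminator,
    vocab_size=100,
    num_classes=2,
    num_token_predictions=20)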
......
@@ -53,9 +53,6 @@ class AlbertTransformerEncoder(tf.keras.Model):
     num_layers: The number of transformer layers.
     num_attention_heads: The number of attention heads for each transformer. The
       hidden size must be divisible by the number of attention heads.
-    sequence_length: The sequence length that this encoder expects. If None, the
-      sequence length is dynamic; if an integer, the encoder will require
-      sequences padded to this length.
     max_sequence_length: The maximum sequence length that this encoder can
       consume. If None, max_sequence_length uses the value from sequence length.
       This determines the variable shape for positional embeddings.
@@ -74,8 +71,7 @@ class AlbertTransformerEncoder(tf.keras.Model):
                hidden_size=768,
                num_layers=12,
                num_attention_heads=12,
-               sequence_length=512,
-               max_sequence_length=None,
+               max_sequence_length=512,
                type_vocab_size=16,
                intermediate_size=3072,
                activation=activations.gelu,
@@ -86,8 +82,6 @@ class AlbertTransformerEncoder(tf.keras.Model):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)

-    if not max_sequence_length:
-      max_sequence_length = sequence_length
     self._self_setattr_tracking = False
     self._config_dict = {
         'vocab_size': vocab_size,
@@ -95,7 +89,6 @@ class AlbertTransformerEncoder(tf.keras.Model):
         'hidden_size': hidden_size,
         'num_layers': num_layers,
         'num_attention_heads': num_attention_heads,
-        'sequence_length': sequence_length,
         'max_sequence_length': max_sequence_length,
         'type_vocab_size': type_vocab_size,
         'intermediate_size': intermediate_size,
@@ -106,11 +99,11 @@ class AlbertTransformerEncoder(tf.keras.Model):
     }
     word_ids = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_word_ids')
+        shape=(None,), dtype=tf.int32, name='input_word_ids')
     mask = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_mask')
+        shape=(None,), dtype=tf.int32, name='input_mask')
     type_ids = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
+        shape=(None,), dtype=tf.int32, name='input_type_ids')
     if embedding_width is None:
       embedding_width = hidden_size
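The ALBERT encoder now always builds its input layers with a dynamic sequence axis, so one instance can consume batches of different padded lengths up to max_sequence_length. A minimal sketch, assuming the official.nlp.modeling.networks export; all sizes below are illustrative:

import numpy as np
from official.nlp.modeling import networks

encoder = networks.AlbertTransformerEncoder(
    vocab_size=100,
    hidden_size=32,
    num_layers=2,
    num_attention_heads=2,
    max_sequence_length=64)  # sizes the positional-embedding table only

# Two batches with different sequence lengths go through the same model.
for seq_len in (16, 48):
  word_ids = np.random.randint(0, 100, size=(2, seq_len)).astype(np.int32)
  mask = np.ones((2, seq_len), dtype=np.int32)
  type_ids = np.zeros((2, seq_len), dtype=np.int32)
  sequence_output, cls_output = encoder([word_ids, mask, type_ids])
  print(sequence_output.shape)  # (2, seq_len, 32)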
......
@@ -48,7 +48,6 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
     kwargs = dict(
         vocab_size=100,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3)
     if expected_dtype == tf.float16:
@@ -92,7 +91,6 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
         vocab_size=vocab_size,
         embedding_width=8,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3,
         type_vocab_size=num_types)
@@ -123,7 +121,6 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
         vocab_size=vocab_size,
         embedding_width=8,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         max_sequence_length=max_sequence_length,
         num_attention_heads=2,
         num_layers=3,
@@ -141,7 +138,6 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
         hidden_size=32,
         num_layers=3,
         num_attention_heads=2,
-        sequence_length=21,
         max_sequence_length=21,
         type_vocab_size=12,
         intermediate_size=1223,
......
@@ -129,16 +129,17 @@ class EncoderScaffold(tf.keras.Model):
       embeddings, attention_mask = self._embedding_network(inputs)
     else:
       self._embedding_network = None
+      seq_length = embedding_cfg.get('seq_length', None)
       word_ids = tf.keras.layers.Input(
-          shape=(embedding_cfg['seq_length'],),
+          shape=(seq_length,),
           dtype=tf.int32,
           name='input_word_ids')
       mask = tf.keras.layers.Input(
-          shape=(embedding_cfg['seq_length'],),
+          shape=(seq_length,),
           dtype=tf.int32,
           name='input_mask')
       type_ids = tf.keras.layers.Input(
-          shape=(embedding_cfg['seq_length'],),
+          shape=(seq_length,),
           dtype=tf.int32,
           name='input_type_ids')
       inputs = [word_ids, mask, type_ids]
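EncoderScaffold now reads seq_length from embedding_cfg with a default of None, so the key can simply be omitted and the scaffold builds dynamic-length input layers. A hedged sketch of such a config dict, using only the keys visible in this commit; the values are illustrative and the remaining EncoderScaffold keyword arguments are elided:

import tensorflow as tf

# No 'seq_length' entry: EncoderScaffold falls back to None and builds
# input_word_ids / input_mask / input_type_ids with shape (None,).
embedding_cfg = dict(
    vocab_size=30522,
    type_vocab_size=2,
    hidden_size=768,
    max_seq_length=512,
    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))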
......
@@ -48,9 +48,8 @@ class TransformerEncoder(tf.keras.Model):
     num_layers: The number of transformer layers.
     num_attention_heads: The number of attention heads for each transformer. The
       hidden size must be divisible by the number of attention heads.
-    sequence_length: The sequence length that this encoder expects. If None, the
-      sequence length is dynamic; if an integer, the encoder will require
-      sequences padded to this length.
+    sequence_length: [Deprecated]. TODO(hongkuny): remove this argument once no
+      user is using it.
     max_sequence_length: The maximum sequence length that this encoder can
       consume. If None, max_sequence_length uses the value from sequence length.
       This determines the variable shape for positional embeddings.
@@ -83,8 +82,8 @@ class TransformerEncoder(tf.keras.Model):
                hidden_size=768,
                num_layers=12,
                num_attention_heads=12,
-               sequence_length=512,
-               max_sequence_length=None,
+               sequence_length=None,
+               max_sequence_length=512,
                type_vocab_size=16,
                intermediate_size=3072,
                activation=activations.gelu,
@@ -99,15 +98,12 @@ class TransformerEncoder(tf.keras.Model):
     activation = tf.keras.activations.get(activation)
     initializer = tf.keras.initializers.get(initializer)

-    if not max_sequence_length:
-      max_sequence_length = sequence_length
     self._self_setattr_tracking = False
     self._config_dict = {
         'vocab_size': vocab_size,
         'hidden_size': hidden_size,
         'num_layers': num_layers,
         'num_attention_heads': num_attention_heads,
-        'sequence_length': sequence_length,
         'max_sequence_length': max_sequence_length,
         'type_vocab_size': type_vocab_size,
         'intermediate_size': intermediate_size,
@@ -121,11 +117,11 @@ class TransformerEncoder(tf.keras.Model):
     }
     word_ids = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_word_ids')
+        shape=(None,), dtype=tf.int32, name='input_word_ids')
     mask = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_mask')
+        shape=(None,), dtype=tf.int32, name='input_mask')
     type_ids = tf.keras.layers.Input(
-        shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
+        shape=(None,), dtype=tf.int32, name='input_type_ids')
     if embedding_width is None:
       embedding_width = hidden_size
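The net effect for TransformerEncoder: sequence_length is now a deprecated no-op, the input layers carry a dynamic sequence axis, and max_sequence_length (default 512) sizes the positional-embedding table. A minimal sketch, assuming the official.nlp.modeling.networks export; the sizes are illustrative:

import numpy as np
from official.nlp.modeling import networks

encoder = networks.TransformerEncoder(
    vocab_size=100,
    hidden_size=16,
    num_attention_heads=2,
    num_layers=2,
    max_sequence_length=128)  # positional capacity, not a fixed input length

word_ids = np.random.randint(0, 100, size=(4, 10)).astype(np.int32)
mask = np.ones((4, 10), dtype=np.int32)
type_ids = np.zeros((4, 10), dtype=np.int32)

sequence_output, cls_output = encoder([word_ids, mask, type_ids])
print(sequence_output.shape)  # (4, 10, 16): sequence axis comes from the batch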
......
@@ -42,7 +42,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=100,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3)
     # Create the inputs (note that the first dimension is implicit).
@@ -71,7 +70,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=100,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3,
         return_all_encoder_outputs=True)
@@ -100,7 +98,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=100,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3)
     # Create the inputs (note that the first dimension is implicit).
@@ -132,7 +129,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         num_attention_heads=2,
         num_layers=3,
         type_vocab_size=num_types,
@@ -163,7 +159,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         max_sequence_length=max_sequence_length,
         num_attention_heads=2,
         num_layers=3,
@@ -177,7 +172,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
     test_network = transformer_encoder.TransformerEncoder(
         vocab_size=vocab_size,
         hidden_size=hidden_size,
-        sequence_length=sequence_length,
         max_sequence_length=max_sequence_length,
         num_attention_heads=2,
         num_layers=3,
@@ -196,7 +190,6 @@ class TransformerEncoderTest(keras_parameterized.TestCase):
         hidden_size=32,
         num_layers=3,
         num_attention_heads=2,
-        sequence_length=21,
         max_sequence_length=21,
         type_vocab_size=12,
         intermediate_size=1223,
......
@@ -413,7 +413,6 @@ def get_bert2bert_layers(params: configs.BERT2BERTConfig):
       activation=tf_utils.get_activation(bert_config.hidden_act),
       dropout_rate=bert_config.hidden_dropout_prob,
       attention_dropout_rate=bert_config.attention_probs_dropout_prob,
-      sequence_length=None,
       max_sequence_length=bert_config.max_position_embeddings,
       type_vocab_size=bert_config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
......