Commit 9824b44e authored by Hongkun Yu, committed by A. Unique TensorFlower

Clean up sequence_length usage for TransformerEncoder.

PiperOrigin-RevId: 326541553
parent 5329de23
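The diffs below apply the same mechanical rename at every call site in the affected tests: the `TransformerEncoder` constructor argument `sequence_length` becomes `max_sequence_length`. A minimal before/after sketch of the migration follows; the `official.nlp.modeling.networks` import path and the comment on the argument's meaning are assumptions based on the Model Garden layout and the new argument name, not statements from this commit.

```python
from official.nlp.modeling import networks

# Before this commit (old argument name):
#   encoder = networks.TransformerEncoder(
#       vocab_size=100, num_layers=2, sequence_length=512)

# After this commit: the same value is passed as `max_sequence_length`.
# The new name suggests it is an upper bound (e.g. for the positional
# embedding table) rather than a fixed length for every input -- an
# assumption inferred from the rename, not spelled out in the commit.
encoder = networks.TransformerEncoder(
    vocab_size=100,
    num_layers=2,
    max_sequence_length=512)
```

Because only the keyword name changes, each call site is a one-line (or one-argument) edit, which is why the hunks below are so uniform.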
@@ -36,7 +36,9 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length)
     # Create a BERT trainer with the created network.
     num_classes = 3
@@ -91,7 +93,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+        vocab_size=100, num_layers=2, max_sequence_length=5)
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
@@ -115,7 +117,9 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
...
@@ -36,7 +36,9 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length)
     # Create a BERT trainer with the created network.
     num_classes = 3
@@ -61,7 +63,7 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=2)
+        vocab_size=100, num_layers=2, max_sequence_length=2)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_token_classifier.BertTokenClassifier(
@@ -82,7 +84,7 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+        vocab_size=100, num_layers=2, max_sequence_length=5)
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
...
@@ -92,9 +92,9 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the ELECTRA trainer. (Here, we
     # use a short sequence_length for convenience.)
     test_generator_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=4, sequence_length=3)
+        vocab_size=100, num_layers=4, max_sequence_length=3)
     test_discriminator_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=4, sequence_length=3)
+        vocab_size=100, num_layers=4, max_sequence_length=3)
     # Create a ELECTRA trainer with the created network.
     eletrca_trainer_model = electra_pretrainer.ElectraPretrainer(
@@ -129,9 +129,9 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_generator_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=4, sequence_length=3)
+        vocab_size=100, num_layers=4, max_sequence_length=3)
     test_discriminator_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=4, sequence_length=3)
+        vocab_size=100, num_layers=4, max_sequence_length=3)
     # Create a ELECTRA trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
...