"docs/EN/vscode:/vscode.git/clone" did not exist on "448fdb426061818cb94c360f19726b8b57bb03cb"
Commit 9824b44e authored by Hongkun Yu, committed by A. Unique TensorFlower

Clean up sequence_length usage for TransformerEncoder.

PiperOrigin-RevId: 326541553
parent 5329de23
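
For reference, a minimal sketch of what this rename means at a call site. The import path official.nlp.modeling is an assumption based on where these tests live, and the parameter values are illustrative, not taken from this commit; all the diff confirms is that the keyword sequence_length becomes max_sequence_length.

# Minimal sketch of the rename (assumed import path; values illustrative).
from official.nlp.modeling import networks

# Before this commit, the length limit was spelled sequence_length:
#   encoder = networks.TransformerEncoder(
#       vocab_size=100, num_layers=2, sequence_length=512)

# After this commit, the same limit is passed as max_sequence_length:
encoder = networks.TransformerEncoder(
    vocab_size=100,
    num_layers=2,
    max_sequence_length=512)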
@@ -36,7 +36,9 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length)
     # Create a BERT trainer with the created network.
     num_classes = 3
@@ -91,7 +93,7 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+        vocab_size=100, num_layers=2, max_sequence_length=5)
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
@@ -115,7 +117,9 @@ class BertPretrainerTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_pretrainer.BertPretrainerV2(
@@ -36,7 +36,9 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     vocab_size = 100
     sequence_length = 512
     test_network = networks.TransformerEncoder(
-        vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
+        vocab_size=vocab_size,
+        num_layers=2,
+        max_sequence_length=sequence_length)
     # Create a BERT trainer with the created network.
     num_classes = 3
@@ -61,7 +63,7 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=2)
+        vocab_size=100, num_layers=2, max_sequence_length=2)
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_token_classifier.BertTokenClassifier(
@@ -82,7 +84,7 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=2, sequence_length=5)
+        vocab_size=100, num_layers=2, max_sequence_length=5)
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
@@ -92,9 +92,9 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the ELECTRA trainer. (Here, we
     # use a short sequence_length for convenience.)
     test_generator_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=4, sequence_length=3)
+        vocab_size=100, num_layers=4, max_sequence_length=3)
     test_discriminator_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=4, sequence_length=3)
+        vocab_size=100, num_layers=4, max_sequence_length=3)
     # Create an ELECTRA trainer with the created network.
     eletrca_trainer_model = electra_pretrainer.ElectraPretrainer(
@@ -129,9 +129,9 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_generator_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=4, sequence_length=3)
+        vocab_size=100, num_layers=4, max_sequence_length=3)
     test_discriminator_network = networks.TransformerEncoder(
-        vocab_size=100, num_layers=4, sequence_length=3)
+        vocab_size=100, num_layers=4, max_sequence_length=3)
     # Create an ELECTRA trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
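
For callers that straddle this rename, a hypothetical compatibility shim (not part of this commit) could translate the retired keyword during migration. build_encoder is an illustrative helper name, and the import path is assumed as in the sketch above.

from official.nlp.modeling import networks


def build_encoder(**kwargs):
  # Hypothetical helper (not in this commit): forward the retired
  # sequence_length keyword to its replacement, max_sequence_length.
  if 'sequence_length' in kwargs:
    kwargs['max_sequence_length'] = kwargs.pop('sequence_length')
  return networks.TransformerEncoder(**kwargs)


# An old-style call keeps working while call sites migrate:
encoder = build_encoder(vocab_size=100, num_layers=2, sequence_length=5)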