Commit 6e0b2ccf authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Make bigbird encoder config more consistent with the encoder class.

PiperOrigin-RevId: 353723657
parent fb9f9ee6
...@@ -136,7 +136,7 @@ class BigBirdEncoderConfig(hyperparams.Config): ...@@ -136,7 +136,7 @@ class BigBirdEncoderConfig(hyperparams.Config):
block_size: int = 64 block_size: int = 64
type_vocab_size: int = 16 type_vocab_size: int = 16
initializer_range: float = 0.02 initializer_range: float = 0.02
embedding_size: Optional[int] = None embedding_width: Optional[int] = None
@dataclasses.dataclass @dataclasses.dataclass
...@@ -290,11 +290,11 @@ def build_encoder(config: EncoderConfig, ...@@ -290,11 +290,11 @@ def build_encoder(config: EncoderConfig,
attention_dropout_rate=encoder_cfg.attention_dropout_rate, attention_dropout_rate=encoder_cfg.attention_dropout_rate,
num_rand_blocks=encoder_cfg.num_rand_blocks, num_rand_blocks=encoder_cfg.num_rand_blocks,
block_size=encoder_cfg.block_size, block_size=encoder_cfg.block_size,
max_sequence_length=encoder_cfg.max_position_embeddings, max_position_embeddings=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size, type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal( initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range), stddev=encoder_cfg.initializer_range),
embedding_width=encoder_cfg.embedding_size) embedding_width=encoder_cfg.embedding_width)
if encoder_type == "xlnet": if encoder_type == "xlnet":
return encoder_cls( return encoder_cls(
......
...@@ -36,9 +36,10 @@ class BigBirdEncoder(tf.keras.Model): ...@@ -36,9 +36,10 @@ class BigBirdEncoder(tf.keras.Model):
num_layers: The number of transformer layers. num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The num_attention_heads: The number of attention heads for each transformer. The
hidden size must be divisible by the number of attention heads. hidden size must be divisible by the number of attention heads.
max_sequence_length: The maximum sequence length that this encoder can max_position_embeddings: The maximum length of position embeddings that this
consume. If None, max_sequence_length uses the value from sequence length. encoder can consume. If None, max_position_embeddings uses the value from
This determines the variable shape for positional embeddings. sequence length. This determines the variable shape for positional
embeddings.
type_vocab_size: The number of types that the 'type_ids' input can take. type_vocab_size: The number of types that the 'type_ids' input can take.
intermediate_size: The intermediate size for the transformer layers. intermediate_size: The intermediate size for the transformer layers.
activation: The activation to use for the transformer layers. activation: The activation to use for the transformer layers.
...@@ -58,7 +59,7 @@ class BigBirdEncoder(tf.keras.Model): ...@@ -58,7 +59,7 @@ class BigBirdEncoder(tf.keras.Model):
hidden_size=768, hidden_size=768,
num_layers=12, num_layers=12,
num_attention_heads=12, num_attention_heads=12,
max_sequence_length=attention.MAX_SEQ_LEN, max_position_embeddings=attention.MAX_SEQ_LEN,
type_vocab_size=16, type_vocab_size=16,
intermediate_size=3072, intermediate_size=3072,
block_size=64, block_size=64,
...@@ -78,7 +79,7 @@ class BigBirdEncoder(tf.keras.Model): ...@@ -78,7 +79,7 @@ class BigBirdEncoder(tf.keras.Model):
'hidden_size': hidden_size, 'hidden_size': hidden_size,
'num_layers': num_layers, 'num_layers': num_layers,
'num_attention_heads': num_attention_heads, 'num_attention_heads': num_attention_heads,
'max_sequence_length': max_sequence_length, 'max_position_embeddings': max_position_embeddings,
'type_vocab_size': type_vocab_size, 'type_vocab_size': type_vocab_size,
'intermediate_size': intermediate_size, 'intermediate_size': intermediate_size,
'block_size': block_size, 'block_size': block_size,
...@@ -109,7 +110,7 @@ class BigBirdEncoder(tf.keras.Model): ...@@ -109,7 +110,7 @@ class BigBirdEncoder(tf.keras.Model):
# Always uses dynamic slicing for simplicity. # Always uses dynamic slicing for simplicity.
self._position_embedding_layer = keras_nlp.layers.PositionEmbedding( self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
initializer=initializer, initializer=initializer,
max_length=max_sequence_length, max_length=max_position_embeddings,
name='position_embedding') name='position_embedding')
position_embeddings = self._position_embedding_layer(word_embeddings) position_embeddings = self._position_embedding_layer(word_embeddings)
self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding( self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
...@@ -159,7 +160,7 @@ class BigBirdEncoder(tf.keras.Model): ...@@ -159,7 +160,7 @@ class BigBirdEncoder(tf.keras.Model):
from_block_size=block_size, from_block_size=block_size,
to_block_size=block_size, to_block_size=block_size,
num_rand_blocks=num_rand_blocks, num_rand_blocks=num_rand_blocks,
max_rand_mask_length=max_sequence_length, max_rand_mask_length=max_position_embeddings,
seed=i), seed=i),
dropout_rate=dropout_rate, dropout_rate=dropout_rate,
attention_dropout_rate=dropout_rate, attention_dropout_rate=dropout_rate,
......
...@@ -27,7 +27,7 @@ class BigBirdEncoderTest(tf.test.TestCase): ...@@ -27,7 +27,7 @@ class BigBirdEncoderTest(tf.test.TestCase):
batch_size = 2 batch_size = 2
vocab_size = 1024 vocab_size = 1024
network = encoder.BigBirdEncoder( network = encoder.BigBirdEncoder(
num_layers=1, vocab_size=1024, max_sequence_length=4096) num_layers=1, vocab_size=1024, max_position_embeddings=4096)
word_id_data = np.random.randint( word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length)) vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length)) mask_data = np.random.randint(2, size=(batch_size, sequence_length))
...@@ -41,7 +41,7 @@ class BigBirdEncoderTest(tf.test.TestCase): ...@@ -41,7 +41,7 @@ class BigBirdEncoderTest(tf.test.TestCase):
batch_size = 2 batch_size = 2
vocab_size = 1024 vocab_size = 1024
network = encoder.BigBirdEncoder( network = encoder.BigBirdEncoder(
num_layers=1, vocab_size=1024, max_sequence_length=4096) num_layers=1, vocab_size=1024, max_position_embeddings=4096)
word_id_data = np.random.randint( word_id_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length)) vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length)) mask_data = np.random.randint(2, size=(batch_size, sequence_length))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment