Commit 6e0b2ccf authored by Hongkun Yu, committed by A. Unique TensorFlower

Make bigbird encoder config more consistent with the encoder class.

PiperOrigin-RevId: 353723657
parent fb9f9ee6
@@ -136,7 +136,7 @@ class BigBirdEncoderConfig(hyperparams.Config):
   block_size: int = 64
   type_vocab_size: int = 16
   initializer_range: float = 0.02
-  embedding_size: Optional[int] = None
+  embedding_width: Optional[int] = None
 
 
 @dataclasses.dataclass
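For reference, a minimal self-contained sketch of the renamed field, using a plain dataclass instead of the project's hyperparams.Config base class (an assumption made so the snippet runs on its own):

import dataclasses
from typing import Optional


@dataclasses.dataclass
class BigBirdEncoderConfigSketch:
  # Hypothetical stand-in for BigBirdEncoderConfig; fields mirror the hunk above.
  block_size: int = 64
  type_vocab_size: int = 16
  initializer_range: float = 0.02
  embedding_width: Optional[int] = None  # renamed from embedding_size


cfg = BigBirdEncoderConfigSketch()
assert cfg.embedding_width is None  # the None default carries over unchanged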
@@ -290,11 +290,11 @@ def build_encoder(config: EncoderConfig,
         attention_dropout_rate=encoder_cfg.attention_dropout_rate,
         num_rand_blocks=encoder_cfg.num_rand_blocks,
         block_size=encoder_cfg.block_size,
-        max_sequence_length=encoder_cfg.max_position_embeddings,
+        max_position_embeddings=encoder_cfg.max_position_embeddings,
         type_vocab_size=encoder_cfg.type_vocab_size,
         initializer=tf.keras.initializers.TruncatedNormal(
             stddev=encoder_cfg.initializer_range),
-        embedding_width=encoder_cfg.embedding_size)
+        embedding_width=encoder_cfg.embedding_width)
 
   if encoder_type == "xlnet":
     return encoder_cls(
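The hunk above is where the rename pays off in build_encoder: the config fields now map one-to-one onto the encoder's constructor keywords. A hedged sketch of that mapping, with SimpleNamespace standing in for the real EncoderConfig (an assumption made for self-containment; the values are illustrative):

from types import SimpleNamespace

# Hypothetical stand-in for encoder_cfg.
encoder_cfg = SimpleNamespace(max_position_embeddings=4096, embedding_width=None)

encoder_kwargs = dict(
    # This keyword was max_sequence_length before this commit.
    max_position_embeddings=encoder_cfg.max_position_embeddings,
    # This config field was embedding_size before this commit.
    embedding_width=encoder_cfg.embedding_width,
)
assert encoder_kwargs['max_position_embeddings'] == 4096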
@@ -36,9 +36,10 @@ class BigBirdEncoder(tf.keras.Model):
     num_layers: The number of transformer layers.
     num_attention_heads: The number of attention heads for each transformer. The
       hidden size must be divisible by the number of attention heads.
-    max_sequence_length: The maximum sequence length that this encoder can
-      consume. If None, max_sequence_length uses the value from sequence length.
-      This determines the variable shape for positional embeddings.
+    max_position_embeddings: The maximum length of position embeddings that this
+      encoder can consume. If None, max_position_embeddings uses the value from
+      sequence length. This determines the variable shape for positional
+      embeddings.
     type_vocab_size: The number of types that the 'type_ids' input can take.
     intermediate_size: The intermediate size for the transformer layers.
     activation: The activation to use for the transformer layers.
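A toy illustration of the docstring semantics above (not the real layer): the positional-embedding table has variable shape [max_position_embeddings, width], and a None value falls back to the input sequence length:

import numpy as np


def position_table(seq_length, width, max_position_embeddings=None):
  # None -> use the sequence length, per the docstring above.
  length = max_position_embeddings or seq_length
  return np.zeros((length, width))


assert position_table(128, 768).shape == (128, 768)
assert position_table(128, 768, max_position_embeddings=4096).shape == (4096, 768)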
@@ -58,7 +59,7 @@ class BigBirdEncoder(tf.keras.Model):
                hidden_size=768,
                num_layers=12,
                num_attention_heads=12,
-               max_sequence_length=attention.MAX_SEQ_LEN,
+               max_position_embeddings=attention.MAX_SEQ_LEN,
                type_vocab_size=16,
                intermediate_size=3072,
                block_size=64,
@@ -78,7 +79,7 @@ class BigBirdEncoder(tf.keras.Model):
         'hidden_size': hidden_size,
         'num_layers': num_layers,
         'num_attention_heads': num_attention_heads,
-        'max_sequence_length': max_sequence_length,
+        'max_position_embeddings': max_position_embeddings,
         'type_vocab_size': type_vocab_size,
         'intermediate_size': intermediate_size,
         'block_size': block_size,
@@ -109,7 +110,7 @@ class BigBirdEncoder(tf.keras.Model):
     # Always uses dynamic slicing for simplicity.
     self._position_embedding_layer = keras_nlp.layers.PositionEmbedding(
         initializer=initializer,
-        max_length=max_sequence_length,
+        max_length=max_position_embeddings,
         name='position_embedding')
     position_embeddings = self._position_embedding_layer(word_embeddings)
     self._type_embedding_layer = keras_nlp.layers.OnDeviceEmbedding(
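A hedged usage sketch of the layer configured above. The import path is an assumption based on the model garden's keras_nlp package at the time of this commit, and 4096 and 0.02 are illustrative values mirroring defaults elsewhere in the diff:

import tensorflow as tf
from official.nlp import keras_nlp  # assumed import path

position_embedding_layer = keras_nlp.layers.PositionEmbedding(
    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    max_length=4096,  # now fed from max_position_embeddings
    name='position_embedding')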
@@ -159,7 +160,7 @@ class BigBirdEncoder(tf.keras.Model):
             from_block_size=block_size,
             to_block_size=block_size,
             num_rand_blocks=num_rand_blocks,
-            max_rand_mask_length=max_sequence_length,
+            max_rand_mask_length=max_position_embeddings,
             seed=i),
         dropout_rate=dropout_rate,
         attention_dropout_rate=dropout_rate,
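The same value also caps the random attention mask in BigBird's block-sparse attention. A rough arithmetic sketch of the block layout (assumption: the mask is allocated per block, as in the BigBird paper; num_rand_blocks=3 is illustrative):

# Illustrative values; block_size=64 matches the constructor default above.
max_position_embeddings = 4096
block_size = 64
num_rand_blocks = 3

# The sequence is partitioned into fixed-size blocks; each query block then
# attends to num_rand_blocks randomly chosen key blocks.
num_blocks = max_position_embeddings // block_size
assert num_blocks == 64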
@@ -27,7 +27,7 @@ class BigBirdEncoderTest(tf.test.TestCase):
     batch_size = 2
     vocab_size = 1024
     network = encoder.BigBirdEncoder(
-        num_layers=1, vocab_size=1024, max_sequence_length=4096)
+        num_layers=1, vocab_size=1024, max_position_embeddings=4096)
     word_id_data = np.random.randint(
         vocab_size, size=(batch_size, sequence_length))
     mask_data = np.random.randint(2, size=(batch_size, sequence_length))
@@ -41,7 +41,7 @@ class BigBirdEncoderTest(tf.test.TestCase):
     batch_size = 2
     vocab_size = 1024
     network = encoder.BigBirdEncoder(
-        num_layers=1, vocab_size=1024, max_sequence_length=4096)
+        num_layers=1, vocab_size=1024, max_position_embeddings=4096)
     word_id_data = np.random.randint(
         vocab_size, size=(batch_size, sequence_length))
     mask_data = np.random.randint(2, size=(batch_size, sequence_length))
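A hedged continuation of the truncated test body above; the input ordering [word_ids, mask, type_ids] is an assumption about the functional model's signature, not shown in this diff:

# Hypothetical continuation of the test above; 16 is the type_vocab_size default.
type_id_data = np.random.randint(16, size=(batch_size, sequence_length))
outputs = network([word_id_data, mask_data, type_id_data])  # assumed input order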