Commit 390c7a93 authored by Jialu Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 398257451
parent 10df8b1d
@@ -204,6 +204,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
   bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
   kernel: KernelEncoderConfig = KernelEncoderConfig()
   mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
+  teams: BertEncoderConfig = BertEncoderConfig()
   xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
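Context, not part of the commit: EncoderConfig is a hyperparams.OneOfConfig, so the new teams member is selected through the `type` attribute and deliberately reuses the BertEncoderConfig schema rather than introducing a new one. A minimal sketch, assuming OneOfConfig's get() accessor behaves as in other model-garden configs:

    from official.nlp.configs import encoders

    # Select the new encoder by name; `teams` is backed by BertEncoderConfig.
    cfg = encoders.EncoderConfig(type="teams")
    teams_cfg = cfg.get()  # assumed accessor returning the active member
    assert isinstance(teams_cfg, encoders.BertEncoderConfig)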
@@ -436,6 +437,40 @@ def build_encoder(config: EncoderConfig,
         initializer=tf.keras.initializers.RandomNormal(
             stddev=encoder_cfg.initializer_range))
+  if encoder_type == "teams":
+    embedding_cfg = dict(
+        vocab_size=encoder_cfg.vocab_size,
+        type_vocab_size=encoder_cfg.type_vocab_size,
+        hidden_size=encoder_cfg.hidden_size,
+        embedding_width=encoder_cfg.embedding_size,
+        max_seq_length=encoder_cfg.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        dropout_rate=encoder_cfg.dropout_rate,
+    )
+    embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
+    hidden_cfg = dict(
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        intermediate_size=encoder_cfg.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(
+            encoder_cfg.hidden_activation),
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+    )
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        embedding_cls=embedding_network,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=encoder_cfg.num_layers,
+        pooled_output_dim=encoder_cfg.hidden_size,
+        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
+        dict_outputs=True)
+    return networks.EncoderScaffold(**kwargs)
   # Uses the default BERTEncoder configuration schema to create the encoder.
   # If it does not match, please add a switch branch by the encoder type.
   return networks.BertEncoder(
......
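A hedged usage sketch of the new branch (not from the commit): building the teams encoder through build_encoder and calling it on dummy inputs. The dict input keys and the sequence_output key are assumptions carried over from other BERT-style encoders in the model garden and from the dict_outputs=True flag above:

    import tensorflow as tf
    from official.nlp.configs import encoders

    encoder = encoders.build_encoder(encoders.EncoderConfig(type="teams"))

    batch_size, seq_len = 2, 8
    dummy_inputs = dict(  # assumed input signature for BERT-style encoders
        input_word_ids=tf.ones((batch_size, seq_len), tf.int32),
        input_mask=tf.ones((batch_size, seq_len), tf.int32),
        input_type_ids=tf.zeros((batch_size, seq_len), tf.int32))
    outputs = encoder(dummy_inputs)
    # With dict_outputs=True the scaffold returns a dict; sequence_output
    # is expected to be [batch_size, seq_len, hidden_size].
    print(outputs["sequence_output"].shape)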
@@ -61,7 +61,6 @@ def bert_sentence_prediction() -> cfg.ExperimentConfig:
           'task.train_data.is_training != None',
           'task.validation_data.is_training != None'
       ])
-  config.task.model.encoder.type = 'bert'
   return config
@@ -98,7 +97,6 @@ def bert_squad() -> cfg.ExperimentConfig:
           'task.train_data.is_training != None',
           'task.validation_data.is_training != None'
       ])
-  config.task.model.encoder.type = 'bert'
   return config
......
@@ -62,6 +62,8 @@ class PackedSequenceEmbedding(tf.keras.Model):
                pack_multiple_sequences=False,
                **kwargs):
     initializer = tf.keras.initializers.get(initializer)
+    if embedding_width is None:
+      embedding_width = hidden_size
     config_dict = {
         'vocab_size': vocab_size,
         'type_vocab_size': type_vocab_size,
......
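The two added lines move the embedding_width fallback into the layer itself: when callers pass None, PackedSequenceEmbedding now defaults it to hidden_size, which is what lets the caller-side defaulting in get_encoder (last hunk below) be deleted. A minimal sketch, assuming the constructor arguments shown in the teams hunk above; the concrete values are illustrative:

    import tensorflow as tf
    from official.nlp.modeling import networks

    # embedding_width=None now falls back to hidden_size (768 here) inside
    # the layer instead of being patched up by every caller.
    embedding_network = networks.PackedSequenceEmbedding(
        vocab_size=30522,
        type_vocab_size=2,
        embedding_width=None,
        hidden_size=768,
        max_seq_length=512,
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        dropout_rate=0.1)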
 task:
   model:
     encoder:
-      bert:
+      teams:
         attention_dropout_rate: 0.1
         dropout_rate: 0.1
         embedding_size: 768
@@ -14,3 +14,4 @@ task:
         num_layers: 12
         type_vocab_size: 2
         vocab_size: 30522
+      type: teams
 task:
   model:
     encoder:
-      bert:
+      teams:
         attention_dropout_rate: 0.1
         dropout_rate: 0.1
         embedding_size: 128
@@ -14,3 +14,4 @@ task:
         num_layers: 12
         type_vocab_size: 2
         vocab_size: 30522
+      type: teams
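With `type: teams` now carried in the YAML, the hard-coded `encoder.type = 'bert'` assignments deleted from the experiment definitions above are unnecessary. The same selection can still be made programmatically, mirroring the deleted lines:

    # Sketch: overriding the encoder type in Python rather than YAML.
    config = bert_squad()  # or bert_sentence_prediction(), both shown above
    config.task.model.encoder.type = 'teams'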
@@ -64,9 +64,6 @@ def get_encoder(bert_config,
   Returns:
     A encoder object.
   """
-  # embedding_size is required for PackedSequenceEmbedding.
-  if bert_config.embedding_size is None:
-    bert_config.embedding_size = bert_config.hidden_size
   embedding_cfg = dict(
       vocab_size=bert_config.vocab_size,
       type_vocab_size=bert_config.type_vocab_size,
......