Commit 390c7a93 authored by Jialu Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 398257451
parent 10df8b1d
@@ -204,6 +204,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
   bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
   kernel: KernelEncoderConfig = KernelEncoderConfig()
   mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
+  teams: BertEncoderConfig = BertEncoderConfig()
   xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
@@ -436,6 +437,40 @@ def build_encoder(config: EncoderConfig,
         initializer=tf.keras.initializers.RandomNormal(
             stddev=encoder_cfg.initializer_range))
+  if encoder_type == "teams":
+    embedding_cfg = dict(
+        vocab_size=encoder_cfg.vocab_size,
+        type_vocab_size=encoder_cfg.type_vocab_size,
+        hidden_size=encoder_cfg.hidden_size,
+        embedding_width=encoder_cfg.embedding_size,
+        max_seq_length=encoder_cfg.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        dropout_rate=encoder_cfg.dropout_rate,
+    )
+    embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
+    hidden_cfg = dict(
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        intermediate_size=encoder_cfg.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(
+            encoder_cfg.hidden_activation),
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+    )
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        embedding_cls=embedding_network,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=encoder_cfg.num_layers,
+        pooled_output_dim=encoder_cfg.hidden_size,
+        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range),
+        return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
+        dict_outputs=True)
+    return networks.EncoderScaffold(**kwargs)
   # Uses the default BERTEncoder configuration schema to create the encoder.
   # If it does not match, please add a switch branch by the encoder type.
   return networks.BertEncoder(
...
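For context, the new branch can be exercised as in the following minimal sketch. It is not part of the commit; it assumes the package layout implied by this diff, and the hyperparameter value is illustrative.

from official.nlp.configs import encoders

# Select the new one-of field added above; "teams" reuses BertEncoderConfig.
encoder_cfg = encoders.EncoderConfig(type="teams")
encoder_cfg.teams.embedding_size = 128  # may differ from hidden_size
# The branch above returns a networks.EncoderScaffold whose embedding
# network is a PackedSequenceEmbedding.
encoder = encoders.build_encoder(encoder_cfg)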
@@ -61,7 +61,6 @@ def bert_sentence_prediction() -> cfg.ExperimentConfig:
           'task.train_data.is_training != None',
           'task.validation_data.is_training != None'
       ])
-  config.task.model.encoder.type = 'bert'
   return config
@@ -98,7 +97,6 @@ def bert_squad() -> cfg.ExperimentConfig:
           'task.train_data.is_training != None',
           'task.validation_data.is_training != None'
       ])
-  config.task.model.encoder.type = 'bert'
   return config
...
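The two deletions above stop pinning the experiment's encoder type to 'bert', so a config file or override can now select a different encoder, such as the teams YAML files further down. A hedged sketch of the resulting workflow; the module path official.nlp.configs.finetuning_experiments is an assumption based on the function names shown in this diff:

from official.nlp.configs import finetuning_experiments  # assumed module path

config = finetuning_experiments.bert_sentence_prediction()
# The factory no longer forces encoder.type, so this choice is now
# left entirely to the experiment config or to an override like this one.
config.task.model.encoder.type = "teams"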
@@ -62,6 +62,8 @@ class PackedSequenceEmbedding(tf.keras.Model):
                pack_multiple_sequences=False,
                **kwargs):
     initializer = tf.keras.initializers.get(initializer)
+    if embedding_width is None:
+      embedding_width = hidden_size
     config_dict = {
         'vocab_size': vocab_size,
         'type_vocab_size': type_vocab_size,
...
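With the two added lines, PackedSequenceEmbedding defaults embedding_width to hidden_size itself, which is what makes the explicit fallback in get_encoder() (removed in the last hunk below) redundant. A rough usage sketch, using only constructor arguments visible in this diff; the import path is an assumption:

from official.nlp.modeling import networks  # assumed import path

embedding_network = networks.PackedSequenceEmbedding(
    vocab_size=30522,
    type_vocab_size=2,
    hidden_size=768,
    max_seq_length=512,
    initializer="truncated_normal",
    dropout_rate=0.1)
# embedding_width was omitted, so it now falls back to hidden_size (768).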
 task:
   model:
     encoder:
-      bert:
+      teams:
         attention_dropout_rate: 0.1
         dropout_rate: 0.1
         embedding_size: 768
@@ -14,3 +14,4 @@ task:
         num_layers: 12
         type_vocab_size: 2
         vocab_size: 30522
+      type: teams
 task:
   model:
     encoder:
-      bert:
+      teams:
         attention_dropout_rate: 0.1
         dropout_rate: 0.1
         embedding_size: 128
@@ -14,3 +14,4 @@ task:
         num_layers: 12
         type_vocab_size: 2
         vocab_size: 30522
+      type: teams
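Taken together, the two YAML diffs rename the encoder block from bert: to teams: and add an explicit type: teams selector; the base file keeps embedding_size: 768 while the small one uses 128. A sketch of the same two configurations built programmatically, using only field names that appear in this diff:

from official.nlp.configs import encoders

# "teams" reuses BertEncoderConfig, so the base/small YAMLs differ only
# in hyperparameter values such as embedding_size.
base = encoders.EncoderConfig(
    type="teams",
    teams=encoders.BertEncoderConfig(embedding_size=768, num_layers=12))
small = encoders.EncoderConfig(
    type="teams",
    teams=encoders.BertEncoderConfig(embedding_size=128, num_layers=12))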
@@ -64,9 +64,6 @@ def get_encoder(bert_config,
   Returns:
     A encoder object.
   """
-  # embedding_size is required for PackedSequenceEmbedding.
-  if bert_config.embedding_size is None:
-    bert_config.embedding_size = bert_config.hidden_size
   embedding_cfg = dict(
       vocab_size=bert_config.vocab_size,
       type_vocab_size=bert_config.type_vocab_size,
...