Commit 9dadc325 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 315792986
parent ef99be0b
@@ -221,6 +221,7 @@ def pretrain_model(bert_config,
      network=transformer_encoder,
      embedding_table=transformer_encoder.get_embedding_table(),
      num_classes=2,  # The next sentence prediction label has two classes.
+     activation=tf_utils.get_activation(bert_config.hidden_act),
      num_token_predictions=max_predictions_per_seq,
      initializer=initializer,
      output='predictions')
...
@@ -74,6 +74,7 @@ def instantiate_from_cfg(
    classification_heads = []
  return bert_pretrainer.BertPretrainerV2(
      config.num_masked_tokens,
+     mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      encoder_network=encoder_network,
...
@@ -47,8 +47,8 @@ class BertPretrainer(tf.keras.Model):
    num_token_predictions: Number of tokens to predict from the masked LM.
    embedding_table: Embedding table of a network. If None, the
      "network.get_embedding_table()" is used.
-   activation: The activation (if any) to use in the masked LM and
-     classification networks. If None, no activation will be used.
+   activation: The activation (if any) to use in the masked LM network.
+     If None, no activation will be used.
    initializer: The initializer (if any) to use in the masked LM and
      classification networks. Defaults to a Glorot uniform initializer.
    output: The output style for this network. Can be either 'logits' or
@@ -151,6 +151,8 @@ class BertPretrainerV2(tf.keras.Model):
    num_masked_tokens: Number of tokens to predict from the masked LM.
    encoder_network: A transformer network. This network should output a
      sequence output and a classification output.
+   mlm_activation: The activation (if any) to use in the masked LM network.
+     If None, no activation will be used.
    mlm_initializer: The initializer (if any) to use in the masked LM. Default
      to a Glorot uniform initializer.
    classification_heads: A list of optional head layers to transform on encoder
@@ -166,6 +168,7 @@ class BertPretrainerV2(tf.keras.Model):
      self,
      num_masked_tokens: int,
      encoder_network: tf.keras.Model,
+     mlm_activation=None,
      mlm_initializer='glorot_uniform',
      classification_heads: Optional[List[tf.keras.layers.Layer]] = None,
      name: str = 'bert',
@@ -194,6 +197,7 @@ class BertPretrainerV2(tf.keras.Model):
        num_predictions=num_masked_tokens,
        input_width=sequence_output.shape[-1],
        source_network=self.encoder_network,
+       activation=mlm_activation,
        initializer=mlm_initializer,
        name='masked_lm')
    masked_lm_positions = copy.copy(self.masked_lm.inputs[-1])
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment