Commit 9dadc325 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 315792986
parent ef99be0b
@@ -221,6 +221,7 @@ def pretrain_model(bert_config,
       network=transformer_encoder,
       embedding_table=transformer_encoder.get_embedding_table(),
       num_classes=2,  # The next sentence prediction label has two classes.
+      activation=tf_utils.get_activation(bert_config.hidden_act),
       num_token_predictions=max_predictions_per_seq,
       initializer=initializer,
       output='predictions')
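For context: both call sites in this change resolve the activation from the config with `tf_utils.get_activation`, which turns a string identifier into a callable. A minimal sketch, assuming the model-garden module path `official.modeling.tf_utils`; the 'gelu' identifier below is illustrative, standing in for `bert_config.hidden_act` / `encoder_cfg.hidden_activation`:

    # Illustrative sketch, not part of this commit.
    from official.modeling import tf_utils

    # Resolve a config string into a callable usable as a Keras activation.
    activation_fn = tf_utils.get_activation('gelu')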
@@ -74,6 +74,7 @@ def instantiate_from_cfg(
     classification_heads = []
   return bert_pretrainer.BertPretrainerV2(
       config.num_masked_tokens,
+      mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=encoder_cfg.initializer_range),
       encoder_network=encoder_network,
@@ -47,8 +47,8 @@ class BertPretrainer(tf.keras.Model):
     num_token_predictions: Number of tokens to predict from the masked LM.
     embedding_table: Embedding table of a network. If None, the
       "network.get_embedding_table()" is used.
-    activation: The activation (if any) to use in the masked LM and
-      classification networks. If None, no activation will be used.
+    activation: The activation (if any) to use in the masked LM network.
+      If None, no activation will be used.
     initializer: The initializer (if any) to use in the masked LM and
       classification networks. Defaults to a Glorot uniform initializer.
     output: The output style for this network. Can be either 'logits' or
@@ -151,6 +151,8 @@ class BertPretrainerV2(tf.keras.Model):
     num_masked_tokens: Number of tokens to predict from the masked LM.
     encoder_network: A transformer network. This network should output a
       sequence output and a classification output.
+    mlm_activation: The activation (if any) to use in the masked LM network.
+      If None, no activation will be used.
     mlm_initializer: The initializer (if any) to use in the masked LM. Default
       to a Glorot uniform initializer.
     classification_heads: A list of optional head layers to transform on encoder
@@ -166,6 +168,7 @@ class BertPretrainerV2(tf.keras.Model):
       self,
       num_masked_tokens: int,
       encoder_network: tf.keras.Model,
+      mlm_activation=None,
       mlm_initializer='glorot_uniform',
       classification_heads: Optional[List[tf.keras.layers.Layer]] = None,
       name: str = 'bert',
@@ -194,6 +197,7 @@ class BertPretrainerV2(tf.keras.Model):
         num_predictions=num_masked_tokens,
         input_width=sequence_output.shape[-1],
         source_network=self.encoder_network,
+        activation=mlm_activation,
         initializer=mlm_initializer,
         name='masked_lm')
     masked_lm_positions = copy.copy(self.masked_lm.inputs[-1])
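Taken together, the change threads an optional `mlm_activation` argument through `BertPretrainerV2` down to its masked-LM head. A minimal usage sketch, assuming `networks.TransformerEncoder` from this repository; the encoder arguments are illustrative and not part of this commit:

    # Illustrative sketch of the new argument in use, not from the commit.
    import tensorflow as tf
    from official.modeling import tf_utils
    from official.nlp.modeling import networks
    from official.nlp.modeling.models import bert_pretrainer

    # Toy encoder; real configs set many more fields (hidden size, heads, ...).
    encoder = networks.TransformerEncoder(vocab_size=30522, num_layers=2)

    pretrainer = bert_pretrainer.BertPretrainerV2(
        20,  # num_masked_tokens
        encoder_network=encoder,
        mlm_activation=tf_utils.get_activation('gelu'),  # new in this change
        mlm_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))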