Commit 0b23ad50 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal Change

PiperOrigin-RevId: 315843901
parent e5641ef5
......@@ -175,7 +175,8 @@ def pretrain_model(bert_config,
seq_length,
max_predictions_per_seq,
initializer=None,
use_next_sentence_label=True):
use_next_sentence_label=True,
return_core_pretrainer_model=False):
"""Returns model to be used for pre-training.
Args:
......@@ -185,10 +186,13 @@ def pretrain_model(bert_config,
and use for pretraining.
initializer: Initializer for weights in BertPretrainer.
use_next_sentence_label: Whether to use the next sentence label.
return_core_pretrainer_model: Whether to also return the `BertPretrainer`
object.
Returns:
Pretraining model as well as core BERT submodel from which to save
weights after pretraining.
A Tuple of (1) Pretraining model, (2) core BERT submodel from which to
save weights after pretraining, and (3) optional core `BertPretrainer`
object if argument `return_core_pretrainer_model` is True.
"""
input_word_ids = tf.keras.layers.Input(
shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
......@@ -245,6 +249,9 @@ def pretrain_model(bert_config,
inputs['next_sentence_labels'] = next_sentence_labels
keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
if return_core_pretrainer_model:
return keras_model, transformer_encoder, pretrainer_model
else:
return keras_model, transformer_encoder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment