"docs/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "8273c3f4500ac9c01e95f1d78fa3f5752aafdca7"
Commit 0b23ad50 authored by Chen Chen, committed by A. Unique TensorFlower
Browse files

Internal Change

PiperOrigin-RevId: 315843901
parent e5641ef5
@@ -175,7 +175,8 @@ def pretrain_model(bert_config,
                    seq_length,
                    max_predictions_per_seq,
                    initializer=None,
-                   use_next_sentence_label=True):
+                   use_next_sentence_label=True,
+                   return_core_pretrainer_model=False):
   """Returns model to be used for pre-training.

   Args:
@@ -185,10 +186,13 @@ def pretrain_model(bert_config,
       and use for pretraining.
     initializer: Initializer for weights in BertPretrainer.
     use_next_sentence_label: Whether to use the next sentence label.
+    return_core_pretrainer_model: Whether to also return the `BertPretrainer`
+      object.

   Returns:
-    Pretraining model as well as core BERT submodel from which to save
-    weights after pretraining.
+    A Tuple of (1) Pretraining model, (2) core BERT submodel from which to
+    save weights after pretraining, and (3) optional core `BertPretrainer`
+    object if argument `return_core_pretrainer_model` is True.
   """
   input_word_ids = tf.keras.layers.Input(
       shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
@@ -245,7 +249,10 @@ def pretrain_model(bert_config,
     inputs['next_sentence_labels'] = next_sentence_labels

   keras_model = tf.keras.Model(inputs=inputs, outputs=output_loss)
-  return keras_model, transformer_encoder
+  if return_core_pretrainer_model:
+    return keras_model, transformer_encoder, pretrainer_model
+  else:
+    return keras_model, transformer_encoder


 def squad_model(bert_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment