Commit d2aac5d7 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Fix tf2_encoder_checkpoint_converter.py

PiperOrigin-RevId: 344192444
parent 38376959
...@@ -50,6 +50,7 @@ BERT_V2_NAME_REPLACEMENTS = ( ...@@ -50,6 +50,7 @@ BERT_V2_NAME_REPLACEMENTS = (
("output/dense", "output"), ("output/dense", "output"),
("output/LayerNorm", "output_layer_norm"), ("output/LayerNorm", "output_layer_norm"),
("pooler/dense", "pooler_transform"), ("pooler/dense", "pooler_transform"),
("cls/predictions", "bert/cls/predictions"),
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"), ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"), ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
("cls/seq_relationship/output_weights", ("cls/seq_relationship/output_weights",
......
...@@ -93,6 +93,8 @@ def _create_bert_pretrainer_model(cfg): ...@@ -93,6 +93,8 @@ def _create_bert_pretrainer_model(cfg):
mlm_activation=tf_utils.get_activation(cfg.hidden_act), mlm_activation=tf_utils.get_activation(cfg.hidden_act),
mlm_initializer=tf.keras.initializers.TruncatedNormal( mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range)) stddev=cfg.initializer_range))
# Makes sure the pretrainer variables are created.
_ = pretrainer(pretrainer.inputs)
return pretrainer return pretrainer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment