Commit ba627d4e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 359545082
parent 48b14968
...@@ -12,14 +12,19 @@ ...@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""A script to export the BERT core model as a TF-Hub SavedModel.""" """A script to export BERT as a TF-Hub SavedModel.
This script is **DEPRECATED** for exporting BERT encoder models;
see the error message in by main() for details.
"""
from typing import Text
# Import libraries # Import libraries
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from typing import Text
from official.nlp.bert import bert_models from official.nlp.bert import bert_models
from official.nlp.bert import configs from official.nlp.bert import configs
...@@ -112,6 +117,14 @@ def export_bert_squad_tfhub(bert_config: configs.BertConfig, ...@@ -112,6 +117,14 @@ def export_bert_squad_tfhub(bert_config: configs.BertConfig,
def main(_): def main(_):
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
if FLAGS.model_type == "encoder": if FLAGS.model_type == "encoder":
deprecation_note = (
"nlp/bert/export_tfhub is **DEPRECATED** for exporting BERT encoder "
"models. Please switch to nlp/tools/export_tfhub for exporting BERT "
"(and other) encoders with dict inputs/outputs conforming to "
"https://www.tensorflow.org/hub/common_saved_model_apis/text#transformer-encoders"
)
logging.error(deprecation_note)
print("\n\nNOTICE:", deprecation_note, "\n")
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case) FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case)
elif FLAGS.model_type == "squad": elif FLAGS.model_type == "squad":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment