Commit 04165099 authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 307712303
parent 330b34fe
...@@ -24,6 +24,7 @@ import os ...@@ -24,6 +24,7 @@ import os
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import gin
import tensorflow as tf import tensorflow as tf
from official.modeling import performance from official.modeling import performance
from official.nlp import optimization from official.nlp import optimization
...@@ -404,6 +405,8 @@ def custom_main(custom_callbacks=None): ...@@ -404,6 +405,8 @@ def custom_main(custom_callbacks=None):
Args: Args:
custom_callbacks: list of tf.keras.Callbacks passed to training loop. custom_callbacks: list of tf.keras.Callbacks passed to training loop.
""" """
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
......
...@@ -25,6 +25,7 @@ import time ...@@ -25,6 +25,7 @@ import time
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import gin
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import configs as bert_configs from official.nlp.bert import configs as bert_configs
...@@ -91,6 +92,7 @@ def export_squad(model_export_path, input_meta_data): ...@@ -91,6 +92,7 @@ def export_squad(model_export_path, input_meta_data):
def main(_): def main(_):
# Users should always run this script under TF 2.x # Users should always run this script under TF 2.x
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment