Commit a12cec09 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 300477605
parent e8afeaee
...@@ -78,7 +78,6 @@ def export_albert_tfhub(albert_config: configs.AlbertConfig, ...@@ -78,7 +78,6 @@ def export_albert_tfhub(albert_config: configs.AlbertConfig,
def main(_): def main(_):
assert tf.version.VERSION.startswith('2.')
albert_config = configs.AlbertConfig.from_json_file( albert_config = configs.AlbertConfig.from_json_file(
FLAGS.albert_config_file) FLAGS.albert_config_file)
export_albert_tfhub(albert_config, FLAGS.model_checkpoint_path, export_albert_tfhub(albert_config, FLAGS.model_checkpoint_path,
......
...@@ -86,5 +86,4 @@ class ExportAlbertTfhubTest(tf.test.TestCase): ...@@ -86,5 +86,4 @@ class ExportAlbertTfhubTest(tf.test.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main() tf.test.main()
...@@ -33,7 +33,6 @@ FLAGS = flags.FLAGS ...@@ -33,7 +33,6 @@ FLAGS = flags.FLAGS
def main(_): def main(_):
# Users should always run this script under TF 2.x # Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
......
...@@ -80,7 +80,6 @@ def export_squad(model_export_path, input_meta_data): ...@@ -80,7 +80,6 @@ def export_squad(model_export_path, input_meta_data):
def main(_): def main(_):
# Users should always run this script under TF 2.x # Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
......
...@@ -122,7 +122,6 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint): ...@@ -122,7 +122,6 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
def main(_): def main(_):
assert tf.version.VERSION.startswith('2.')
output_path = FLAGS.converted_checkpoint_path output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert v1_checkpoint = FLAGS.checkpoint_to_convert
albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file) albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
......
...@@ -77,7 +77,6 @@ def export_bert_tfhub(bert_config: configs.BertConfig, ...@@ -77,7 +77,6 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
def main(_): def main(_):
assert tf.version.VERSION.startswith('2.')
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path, export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
FLAGS.vocab_file) FLAGS.vocab_file)
......
...@@ -84,5 +84,4 @@ class ExportTfhubTest(tf.test.TestCase): ...@@ -84,5 +84,4 @@ class ExportTfhubTest(tf.test.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main() tf.test.main()
...@@ -403,7 +403,6 @@ def run_bert(strategy, ...@@ -403,7 +403,6 @@ def run_bert(strategy,
def main(_): def main(_):
# Users should always run this script under TF 2.x # Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
......
...@@ -159,7 +159,6 @@ def run_bert_pretrain(strategy): ...@@ -159,7 +159,6 @@ def run_bert_pretrain(strategy):
def main(_): def main(_):
# Users should always run this script under TF 2.x # Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param) gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
if not FLAGS.model_dir: if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/' FLAGS.model_dir = '/tmp/bert20/'
......
...@@ -76,7 +76,6 @@ def export_squad(model_export_path, input_meta_data): ...@@ -76,7 +76,6 @@ def export_squad(model_export_path, input_meta_data):
def main(_): def main(_):
# Users should always run this script under TF 2.x # Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader: with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8')) input_meta_data = json.loads(reader.read().decode('utf-8'))
......
...@@ -98,7 +98,7 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint): ...@@ -98,7 +98,7 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
def main(_): def main(_):
assert tf.version.VERSION.startswith('2.') tf.enable_v2_behavior()
output_path = FLAGS.converted_checkpoint_path output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert v1_checkpoint = FLAGS.checkpoint_to_convert
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
......
...@@ -186,5 +186,4 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -186,5 +186,4 @@ class TransformerLayerTest(keras_parameterized.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main() tf.test.main()
...@@ -19,7 +19,7 @@ from __future__ import division ...@@ -19,7 +19,7 @@ from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow.compat.v2 as tf
from official.nlp.modeling import networks from official.nlp.modeling import networks
......
...@@ -18,7 +18,7 @@ from __future__ import absolute_import ...@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow.compat.v2 as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks from official.nlp.modeling import networks
......
...@@ -20,7 +20,7 @@ from __future__ import division ...@@ -20,7 +20,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import copy import copy
import tensorflow as tf import tensorflow.compat.v2 as tf
from official.nlp.modeling import networks from official.nlp.modeling import networks
......
...@@ -18,7 +18,7 @@ from __future__ import absolute_import ...@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow.compat.v2 as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks from official.nlp.modeling import networks
......
...@@ -19,7 +19,7 @@ from __future__ import division ...@@ -19,7 +19,7 @@ from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow.compat.v2 as tf
from official.nlp.modeling import networks from official.nlp.modeling import networks
......
...@@ -18,7 +18,7 @@ from __future__ import absolute_import ...@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow.compat.v2 as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks from official.nlp.modeling import networks
......
...@@ -171,5 +171,4 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase): ...@@ -171,5 +171,4 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main() tf.test.main()
...@@ -626,5 +626,4 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase): ...@@ -626,5 +626,4 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment