"verl/workers/vscode:/vscode.git/clone" did not exist on "c132cbcbe79197593a8377d08f4b13a172a7b464"
Commit a12cec09 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 300477605
parent e8afeaee
......@@ -78,7 +78,6 @@ def export_albert_tfhub(albert_config: configs.AlbertConfig,
def main(_):
assert tf.version.VERSION.startswith('2.')
albert_config = configs.AlbertConfig.from_json_file(
FLAGS.albert_config_file)
export_albert_tfhub(albert_config, FLAGS.model_checkpoint_path,
......
......@@ -86,5 +86,4 @@ class ExportAlbertTfhubTest(tf.test.TestCase):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -33,7 +33,6 @@ FLAGS = flags.FLAGS
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
......
......@@ -80,7 +80,6 @@ def export_squad(model_export_path, input_meta_data):
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
......
......@@ -122,7 +122,6 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
def main(_):
assert tf.version.VERSION.startswith('2.')
output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert
albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
......
......@@ -77,7 +77,6 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
def main(_):
assert tf.version.VERSION.startswith('2.')
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
FLAGS.vocab_file)
......
......@@ -84,5 +84,4 @@ class ExportTfhubTest(tf.test.TestCase):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -403,7 +403,6 @@ def run_bert(strategy,
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
......
......@@ -159,7 +159,6 @@ def run_bert_pretrain(strategy):
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
if not FLAGS.model_dir:
FLAGS.model_dir = '/tmp/bert20/'
......
......@@ -76,7 +76,6 @@ def export_squad(model_export_path, input_meta_data):
def main(_):
# Users should always run this script under TF 2.x
assert tf.version.VERSION.startswith('2.')
with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
input_meta_data = json.loads(reader.read().decode('utf-8'))
......
......@@ -98,7 +98,7 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
def main(_):
assert tf.version.VERSION.startswith('2.')
tf.enable_v2_behavior()
output_path = FLAGS.converted_checkpoint_path
v1_checkpoint = FLAGS.checkpoint_to_convert
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
......
......@@ -186,5 +186,4 @@ class TransformerLayerTest(keras_parameterized.TestCase):
if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -19,7 +19,7 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
import tensorflow.compat.v2 as tf
from official.nlp.modeling import networks
......
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.compat.v2 as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
......
......@@ -20,7 +20,7 @@ from __future__ import division
from __future__ import print_function
import copy
import tensorflow as tf
import tensorflow.compat.v2 as tf
from official.nlp.modeling import networks
......
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.compat.v2 as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
......
......@@ -19,7 +19,7 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
import tensorflow.compat.v2 as tf
from official.nlp.modeling import networks
......
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.compat.v2 as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
......
......@@ -171,5 +171,4 @@ class AlbertTransformerEncoderTest(keras_parameterized.TestCase):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main()
......@@ -626,5 +626,4 @@ class EncoderScaffoldHiddenInstanceTest(keras_parameterized.TestCase):
if __name__ == "__main__":
assert tf.version.VERSION.startswith('2.')
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment