Commit 104488e4 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 304908157
parent 92e6cd93
...@@ -23,7 +23,7 @@ import time ...@@ -23,7 +23,7 @@ import time
# pylint: disable=g-bad-import-order # pylint: disable=g-bad-import-order
import numpy as np import numpy as np
from absl import flags from absl import flags
import tensorflow.compat.v2 as tf import tensorflow as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
......
...@@ -104,7 +104,6 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -104,7 +104,6 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
@flagsaver.flagsaver @flagsaver.flagsaver
def _train_squad(self, run_eagerly=False, ds_type='mirrored'): def _train_squad(self, run_eagerly=False, ds_type='mirrored'):
"""Runs BERT SQuAD training. Uses mirrored strategy by default.""" """Runs BERT SQuAD training. Uses mirrored strategy by default."""
assert tf.version.VERSION.startswith('2.')
self._init_gpu_and_data_threads() self._init_gpu_and_data_threads()
input_meta_data = self._read_input_meta_data_from_file() input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(ds_type) strategy = self._get_distribution_strategy(ds_type)
...@@ -118,7 +117,6 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -118,7 +117,6 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
@flagsaver.flagsaver @flagsaver.flagsaver
def _evaluate_squad(self, ds_type='mirrored'): def _evaluate_squad(self, ds_type='mirrored'):
"""Runs BERT SQuAD evaluation. Uses mirrored strategy by default.""" """Runs BERT SQuAD evaluation. Uses mirrored strategy by default."""
assert tf.version.VERSION.startswith('2.')
self._init_gpu_and_data_threads() self._init_gpu_and_data_threads()
input_meta_data = self._read_input_meta_data_from_file() input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(ds_type) strategy = self._get_distribution_strategy(ds_type)
......
...@@ -32,7 +32,8 @@ class KerasBenchmark(PerfZeroBenchmark): ...@@ -32,7 +32,8 @@ class KerasBenchmark(PerfZeroBenchmark):
default_flags=None, default_flags=None,
flag_methods=None, flag_methods=None,
tpu=None): tpu=None):
assert tf.version.VERSION.startswith('2.') # Due to xla legacy benchmark.
tf.compat.v1.enable_v2_behavior()
super(KerasBenchmark, self).__init__( super(KerasBenchmark, self).__init__(
output_dir=output_dir, output_dir=output_dir,
default_flags=default_flags, default_flags=default_flags,
......
...@@ -50,7 +50,6 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark): ...@@ -50,7 +50,6 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark):
def _setup(self): def _setup(self):
"""Sets up and resets flags before each test.""" """Sets up and resets flags before each test."""
assert tf.version.VERSION.startswith('2.')
logging.set_verbosity(logging.INFO) logging.set_verbosity(logging.INFO)
if NCFKerasBenchmarkBase.local_flags is None: if NCFKerasBenchmarkBase.local_flags is None:
ncf_common.define_ncf_flags() ncf_common.define_ncf_flags()
......
...@@ -20,7 +20,7 @@ import functools ...@@ -20,7 +20,7 @@ import functools
import time import time
from absl import flags from absl import flags
import tensorflow.compat.v2 as tf import tensorflow as tf
import tensorflow_hub as hub import tensorflow_hub as hub
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
......
...@@ -44,7 +44,8 @@ class TransformerBenchmark(PerfZeroBenchmark): ...@@ -44,7 +44,8 @@ class TransformerBenchmark(PerfZeroBenchmark):
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None, def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None): flag_methods=None):
assert tf.version.VERSION.startswith('2.') # Due to xla legacy benchmark.
tf.compat.v1.enable_v2_behavior()
root_data_dir = root_data_dir if root_data_dir else '' root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir, self.train_data_dir = os.path.join(root_data_dir,
......
...@@ -29,7 +29,7 @@ import os ...@@ -29,7 +29,7 @@ import os
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
......
...@@ -251,6 +251,5 @@ def main(argv): ...@@ -251,6 +251,5 @@ def main(argv):
if __name__ == '__main__': if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.config.set_soft_device_placement(True) tf.config.set_soft_device_placement(True)
app.run(main) app.run(main)
...@@ -59,7 +59,7 @@ def main(argv): ...@@ -59,7 +59,7 @@ def main(argv):
if len(argv) > 1: if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.") raise app.UsageError("Too many command-line arguments.")
assert tf.version.VERSION.startswith('2.') tf.enable_v2_behavior()
export_tfhub(FLAGS.model_path, FLAGS.export_path) export_tfhub(FLAGS.model_path, FLAGS.export_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment