Commit 104488e4 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 304908157
parent 92e6cd93
......@@ -23,7 +23,7 @@ import time
# pylint: disable=g-bad-import-order
import numpy as np
from absl import flags
import tensorflow.compat.v2 as tf
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.utils.flags import core as flags_core
......
......@@ -104,7 +104,6 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
@flagsaver.flagsaver
def _train_squad(self, run_eagerly=False, ds_type='mirrored'):
"""Runs BERT SQuAD training. Uses mirrored strategy by default."""
assert tf.version.VERSION.startswith('2.')
self._init_gpu_and_data_threads()
input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(ds_type)
......@@ -118,7 +117,6 @@ class BertSquadBenchmarkBase(benchmark_utils.BertBenchmarkBase):
@flagsaver.flagsaver
def _evaluate_squad(self, ds_type='mirrored'):
"""Runs BERT SQuAD evaluation. Uses mirrored strategy by default."""
assert tf.version.VERSION.startswith('2.')
self._init_gpu_and_data_threads()
input_meta_data = self._read_input_meta_data_from_file()
strategy = self._get_distribution_strategy(ds_type)
......
......@@ -32,7 +32,8 @@ class KerasBenchmark(PerfZeroBenchmark):
default_flags=None,
flag_methods=None,
tpu=None):
assert tf.version.VERSION.startswith('2.')
# Due to xla legacy benchmark.
tf.compat.v1.enable_v2_behavior()
super(KerasBenchmark, self).__init__(
output_dir=output_dir,
default_flags=default_flags,
......
......@@ -50,7 +50,6 @@ class NCFKerasBenchmarkBase(tf.test.Benchmark):
def _setup(self):
"""Sets up and resets flags before each test."""
assert tf.version.VERSION.startswith('2.')
logging.set_verbosity(logging.INFO)
if NCFKerasBenchmarkBase.local_flags is None:
ncf_common.define_ncf_flags()
......
......@@ -20,7 +20,7 @@ import functools
import time
from absl import flags
import tensorflow.compat.v2 as tf
import tensorflow as tf
import tensorflow_hub as hub
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
......
......@@ -44,7 +44,8 @@ class TransformerBenchmark(PerfZeroBenchmark):
def __init__(self, output_dir=None, default_flags=None, root_data_dir=None,
flag_methods=None):
assert tf.version.VERSION.startswith('2.')
# Due to xla legacy benchmark.
tf.compat.v1.enable_v2_behavior()
root_data_dir = root_data_dir if root_data_dir else ''
self.train_data_dir = os.path.join(root_data_dir,
......
......@@ -29,7 +29,7 @@ import os
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order
from official.recommendation import constants as rconst
......
......@@ -251,6 +251,5 @@ def main(argv):
if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.config.set_soft_device_placement(True)
app.run(main)
......@@ -59,7 +59,7 @@ def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
assert tf.version.VERSION.startswith('2.')
tf.enable_v2_behavior()
export_tfhub(FLAGS.model_path, FLAGS.export_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment