Commit c2666cea authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

[Clean up] Remove enable_eager in the session config: Model garden is TF2 only now.

Remove is_v2_0

PiperOrigin-RevId: 312336907
parent 4ec2ee97
...@@ -22,7 +22,7 @@ import os ...@@ -22,7 +22,7 @@ import os
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
import tensorflow as tf import tensorflow.compat.v1 as tf
from official.r1.utils.logs import logger from official.r1.utils.logs import logger
from official.r1.wide_deep import movielens_dataset from official.r1.wide_deep import movielens_dataset
from official.r1.wide_deep import wide_deep_run_loop from official.r1.wide_deep import wide_deep_run_loop
......
...@@ -18,13 +18,11 @@ from __future__ import division ...@@ -18,13 +18,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import unittest
import numpy as np import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow.compat.v1 as tf
from official.recommendation import movielens from official.recommendation import movielens
from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
from official.r1.wide_deep import movielens_dataset from official.r1.wide_deep import movielens_dataset
from official.r1.wide_deep import movielens_main from official.r1.wide_deep import movielens_main
...@@ -85,7 +83,6 @@ class BaseTest(tf.test.TestCase): ...@@ -85,7 +83,6 @@ class BaseTest(tf.test.TestCase):
with tf.io.gfile.GFile(self.item_csv, "w") as f: with tf.io.gfile.GFile(self.item_csv, "w") as f:
f.write(TEST_ITEM_DATA) f.write(TEST_ITEM_DATA)
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
def test_input_fn(self): def test_input_fn(self):
train_input_fn, _, _ = movielens_dataset.construct_input_fns( train_input_fn, _, _ = movielens_dataset.construct_input_fns(
dataset=movielens.ML_1M, data_dir=self.temp_dir, batch_size=8, repeat=1) dataset=movielens.ML_1M, data_dir=self.temp_dir, batch_size=8, repeat=1)
...@@ -103,7 +100,6 @@ class BaseTest(tf.test.TestCase): ...@@ -103,7 +100,6 @@ class BaseTest(tf.test.TestCase):
self.assertAllClose(labels[0], [1.0]) self.assertAllClose(labels[0], [1.0])
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
def test_end_to_end_deep(self): def test_end_to_end_deep(self):
integration.run_synthetic( integration.run_synthetic(
main=movielens_main.main, tmp_root=self.temp_dir, main=movielens_main.main, tmp_root=self.temp_dir,
...@@ -117,4 +113,5 @@ class BaseTest(tf.test.TestCase): ...@@ -117,4 +113,5 @@ class BaseTest(tf.test.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
tf.disable_eager_execution()
tf.test.main() tf.test.main()
...@@ -23,7 +23,7 @@ import shutil ...@@ -23,7 +23,7 @@ import shutil
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
import tensorflow as tf import tensorflow.compat.v1 as tf
from official.r1.utils.logs import hooks_helper from official.r1.utils.logs import hooks_helper
from official.r1.utils.logs import logger from official.r1.utils.logs import logger
......
...@@ -31,7 +31,6 @@ from official.recommendation import constants as rconst ...@@ -31,7 +31,6 @@ from official.recommendation import constants as rconst
from official.recommendation import data_preprocessing from official.recommendation import data_preprocessing
from official.recommendation import movielens from official.recommendation import movielens
from official.recommendation import popen_helper from official.recommendation import popen_helper
from official.utils.misc import keras_utils
DATASET = "ml-test" DATASET = "ml-test"
...@@ -59,8 +58,7 @@ def mock_download(*args, **kwargs): ...@@ -59,8 +58,7 @@ def mock_download(*args, **kwargs):
class BaseTest(tf.test.TestCase): class BaseTest(tf.test.TestCase):
def setUp(self): def setUp(self):
if keras_utils.is_v2_0: tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_eager_execution()
self.temp_data_dir = self.get_temp_dir() self.temp_data_dir = self.get_temp_dir()
ratings_folder = os.path.join(self.temp_data_dir, DATASET) ratings_folder = os.path.join(self.temp_data_dir, DATASET)
tf.io.gfile.makedirs(ratings_folder) tf.io.gfile.makedirs(ratings_folder)
......
...@@ -220,14 +220,6 @@ def run_ncf(_): ...@@ -220,14 +220,6 @@ def run_ncf(_):
params = ncf_common.parse_flags(FLAGS) params = ncf_common.parse_flags(FLAGS)
params["distribute_strategy"] = strategy params["distribute_strategy"] = strategy
if not keras_utils.is_v2_0() and strategy is not None:
logging.error("NCF Keras only works with distribution strategy in TF 2.0")
return
if (params["keras_use_ctl"] and (
not keras_utils.is_v2_0() or strategy is None)):
logging.error(
"Custom training loop only works with tensorflow 2.0 and dist strat.")
return
if params["use_tpu"] and not params["keras_use_ctl"]: if params["use_tpu"] and not params["keras_use_ctl"]:
logging.error("Custom training loop must be used when using TPUStrategy.") logging.error("Custom training loop must be used when using TPUStrategy.")
return return
......
...@@ -18,21 +18,15 @@ from __future__ import absolute_import ...@@ -18,21 +18,15 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import math
import unittest import unittest
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
from official.recommendation import ncf_common from official.recommendation import ncf_common
from official.recommendation import ncf_keras_main from official.recommendation import ncf_keras_main
from official.recommendation import neumf_model
from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
NUM_TRAIN_NEG = 4 NUM_TRAIN_NEG = 4
...@@ -52,139 +46,6 @@ class NcfTest(tf.test.TestCase): ...@@ -52,139 +46,6 @@ class NcfTest(tf.test.TestCase):
rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old
rconst.TOP_K = self.top_k_old rconst.TOP_K = self.top_k_old
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
def get_hit_rate_and_ndcg(self, predicted_scores_by_user, items_by_user,
top_k=rconst.TOP_K, match_mlperf=False):
rconst.TOP_K = top_k
rconst.NUM_EVAL_NEGATIVES = predicted_scores_by_user.shape[1] - 1
batch_size = items_by_user.shape[0]
users = np.repeat(np.arange(batch_size)[:, np.newaxis],
rconst.NUM_EVAL_NEGATIVES + 1, axis=1)
users, items, duplicate_mask = \
data_pipeline.BaseDataConstructor._assemble_eval_batch(
users, items_by_user[:, -1:], items_by_user[:, :-1], batch_size)
g = tf.Graph()
with g.as_default():
logits = tf.convert_to_tensor(
predicted_scores_by_user.reshape((-1, 1)), tf.float32)
softmax_logits = tf.concat([tf.zeros(logits.shape, dtype=logits.dtype),
logits], axis=1)
duplicate_mask = tf.convert_to_tensor(duplicate_mask, tf.float32)
metric_ops = neumf_model._get_estimator_spec_with_metrics(
logits=logits, softmax_logits=softmax_logits,
duplicate_mask=duplicate_mask, num_training_neg=NUM_TRAIN_NEG,
match_mlperf=match_mlperf).eval_metric_ops
hr = metric_ops[rconst.HR_KEY]
ndcg = metric_ops[rconst.NDCG_KEY]
init = [tf.compat.v1.global_variables_initializer(),
tf.compat.v1.local_variables_initializer()]
with self.session(graph=g) as sess:
sess.run(init)
return sess.run([hr[1], ndcg[1]])
def test_hit_rate_and_ndcg(self):
# Test with no duplicate items
predictions = np.array([
[2., 0., 1.], # In top 2
[1., 0., 2.], # In top 1
[2., 1., 0.], # In top 3
[3., 4., 2.] # In top 3
])
items = np.array([
[2, 3, 1],
[3, 1, 2],
[2, 1, 3],
[1, 3, 2],
])
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1)
self.assertAlmostEqual(hr, 1 / 4)
self.assertAlmostEqual(ndcg, 1 / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2)
self.assertAlmostEqual(hr, 2 / 4)
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3)
self.assertAlmostEqual(hr, 4 / 4)
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
2 * math.log(2) / math.log(4)) / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1,
match_mlperf=True)
self.assertAlmostEqual(hr, 1 / 4)
self.assertAlmostEqual(ndcg, 1 / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2,
match_mlperf=True)
self.assertAlmostEqual(hr, 2 / 4)
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3,
match_mlperf=True)
self.assertAlmostEqual(hr, 4 / 4)
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
2 * math.log(2) / math.log(4)) / 4)
# Test with duplicate items. In the MLPerf case, we treat the duplicates as
# a single item. Otherwise, we treat the duplicates as separate items.
predictions = np.array([
[2., 2., 3., 1.], # In top 4. MLPerf: In top 3
[1., 0., 2., 3.], # In top 1. MLPerf: In top 1
[2., 3., 2., 0.], # In top 4. MLPerf: In top 3
[2., 4., 2., 3.] # In top 2. MLPerf: In top 2
])
items = np.array([
[2, 2, 3, 1],
[2, 3, 4, 1],
[2, 3, 2, 1],
[3, 2, 1, 4],
])
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1)
self.assertAlmostEqual(hr, 1 / 4)
self.assertAlmostEqual(ndcg, 1 / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2)
self.assertAlmostEqual(hr, 2 / 4)
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3)
self.assertAlmostEqual(hr, 2 / 4)
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 4)
self.assertAlmostEqual(hr, 4 / 4)
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
2 * math.log(2) / math.log(5)) / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1,
match_mlperf=True)
self.assertAlmostEqual(hr, 1 / 4)
self.assertAlmostEqual(ndcg, 1 / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2,
match_mlperf=True)
self.assertAlmostEqual(hr, 2 / 4)
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3,
match_mlperf=True)
self.assertAlmostEqual(hr, 4 / 4)
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
2 * math.log(2) / math.log(4)) / 4)
hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 4,
match_mlperf=True)
self.assertAlmostEqual(hr, 4 / 4)
self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
2 * math.log(2) / math.log(4)) / 4)
_BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1'] _BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
...@@ -195,14 +56,12 @@ class NcfTest(tf.test.TestCase): ...@@ -195,14 +56,12 @@ class NcfTest(tf.test.TestCase):
['-distribution_strategy', 'off']) ['-distribution_strategy', 'off'])
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat(self): def test_end_to_end_keras_dist_strat(self):
integration.run_synthetic( integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0']) extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat_ctl(self): def test_end_to_end_keras_dist_strat_ctl(self):
flags = (self._BASE_END_TO_END_FLAGS + flags = (self._BASE_END_TO_END_FLAGS +
['-num_gpus', '0'] + ['-num_gpus', '0'] +
...@@ -212,7 +71,6 @@ class NcfTest(tf.test.TestCase): ...@@ -212,7 +71,6 @@ class NcfTest(tf.test.TestCase):
extra_flags=flags) extra_flags=flags)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_1_gpu_dist_strat_fp16(self): def test_end_to_end_keras_1_gpu_dist_strat_fp16(self):
if context.num_gpus() < 1: if context.num_gpus() < 1:
self.skipTest( self.skipTest(
...@@ -225,7 +83,6 @@ class NcfTest(tf.test.TestCase): ...@@ -225,7 +83,6 @@ class NcfTest(tf.test.TestCase):
'--dtype', 'fp16']) '--dtype', 'fp16'])
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_1_gpu_dist_strat_ctl_fp16(self): def test_end_to_end_keras_1_gpu_dist_strat_ctl_fp16(self):
if context.num_gpus() < 1: if context.num_gpus() < 1:
self.skipTest( self.skipTest(
...@@ -239,7 +96,6 @@ class NcfTest(tf.test.TestCase): ...@@ -239,7 +96,6 @@ class NcfTest(tf.test.TestCase):
'--keras_use_ctl']) '--keras_use_ctl'])
@unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100) @unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_2_gpu_fp16(self): def test_end_to_end_keras_2_gpu_fp16(self):
if context.num_gpus() < 2: if context.num_gpus() < 2:
self.skipTest( self.skipTest(
......
...@@ -24,7 +24,6 @@ import time ...@@ -24,7 +24,6 @@ import time
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow.python import tf2
class BatchTimestamp(object): class BatchTimestamp(object):
...@@ -150,39 +149,13 @@ class SimpleCheckpoint(tf.keras.callbacks.Callback): ...@@ -150,39 +149,13 @@ class SimpleCheckpoint(tf.keras.callbacks.Callback):
self.checkpoint_manager.save(checkpoint_number=step_counter) self.checkpoint_manager.save(checkpoint_number=step_counter)
def set_session_config(enable_eager=False, def set_session_config(enable_xla=False):
enable_xla=False):
"""Sets the session config.""" """Sets the session config."""
if is_v2_0():
set_config_v2(enable_xla=enable_xla)
else:
config = get_config_proto_v1(enable_xla=enable_xla)
if enable_eager:
tf.compat.v1.enable_eager_execution(config=config)
else:
sess = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(sess)
def get_config_proto_v1(enable_xla=False):
"""Return config proto according to flag settings, or None to use default."""
config = None
if enable_xla:
config = tf.compat.v1.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
tf.OptimizerOptions.ON_2)
return config
def set_config_v2(enable_xla=False):
"""Config eager context according to flag values using TF 2.0 API."""
if enable_xla: if enable_xla:
tf.config.optimizer.set_jit(True) tf.config.optimizer.set_jit(True)
# TODO(hongkuny): remove set_config_v2 globally.
def is_v2_0(): set_config_v2 = set_session_config
"""Returns true if using tf 2.0."""
return tf2.enabled()
def set_gpu_thread_mode_and_count(gpu_thread_mode, def set_gpu_thread_mode_and_count(gpu_thread_mode,
......
...@@ -20,7 +20,6 @@ from __future__ import print_function ...@@ -20,7 +20,6 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
...@@ -29,8 +28,7 @@ class PastStopThresholdTest(tf.test.TestCase): ...@@ -29,8 +28,7 @@ class PastStopThresholdTest(tf.test.TestCase):
def setUp(self): def setUp(self):
super(PastStopThresholdTest, self).setUp() super(PastStopThresholdTest, self).setUp()
if keras_utils.is_v2_0: tf.compat.v1.disable_eager_execution()
tf.compat.v1.disable_eager_execution()
def test_past_stop_threshold(self): def test_past_stop_threshold(self):
"""Tests for normal operating conditions.""" """Tests for normal operating conditions."""
......
...@@ -25,7 +25,6 @@ import tensorflow as tf ...@@ -25,7 +25,6 @@ import tensorflow as tf
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import strategy_combinations
from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
from official.vision.image_classification import mnist_main from official.vision.image_classification import mnist_main
...@@ -57,8 +56,6 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase): ...@@ -57,8 +56,6 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(eager_strategy_combinations()) @combinations.generate(eager_strategy_combinations())
def test_end_to_end(self, distribution): def test_end_to_end(self, distribution):
"""Test Keras MNIST model with `strategy`.""" """Test Keras MNIST model with `strategy`."""
config = keras_utils.get_config_proto_v1()
tf.compat.v1.enable_eager_execution(config=config)
extra_flags = [ extra_flags = [
"-train_epochs", "1", "-train_epochs", "1",
......
...@@ -109,7 +109,6 @@ def run(flags_obj): ...@@ -109,7 +109,6 @@ def run(flags_obj):
Dictionary of training and eval stats. Dictionary of training and eval stats.
""" """
keras_utils.set_session_config( keras_utils.set_session_config(
enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla) enable_xla=flags_obj.enable_xla)
performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj)) performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment