Commit 4b8f80c3 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Use unittest.mock as we are py3 now

PiperOrigin-RevId: 296944580
parent 02d78796
...@@ -21,7 +21,6 @@ from __future__ import print_function ...@@ -21,7 +21,6 @@ from __future__ import print_function
import math import math
import unittest import unittest
import mock
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -192,34 +191,34 @@ class NcfTest(tf.test.TestCase): ...@@ -192,34 +191,34 @@ class NcfTest(tf.test.TestCase):
_BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1'] _BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)") @unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator(self): def test_end_to_end_estimator(self):
integration.run_synthetic( integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS) extra_flags=self._BASE_END_TO_END_FLAGS)
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)") @unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self): def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic( integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True']) extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_keras_no_dist_strat(self): def test_end_to_end_keras_no_dist_strat(self):
integration.run_synthetic( integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + extra_flags=self._BASE_END_TO_END_FLAGS +
['-distribution_strategy', 'off']) ['-distribution_strategy', 'off'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.') @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat(self): def test_end_to_end_keras_dist_strat(self):
integration.run_synthetic( integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0']) extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.') @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat_ctl(self): def test_end_to_end_keras_dist_strat_ctl(self):
flags = (self._BASE_END_TO_END_FLAGS + flags = (self._BASE_END_TO_END_FLAGS +
...@@ -229,7 +228,7 @@ class NcfTest(tf.test.TestCase): ...@@ -229,7 +228,7 @@ class NcfTest(tf.test.TestCase):
ncf_keras_main.main, tmp_root=self.get_temp_dir(), ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=flags) extra_flags=flags)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.') @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_1_gpu_dist_strat_fp16(self): def test_end_to_end_keras_1_gpu_dist_strat_fp16(self):
if context.num_gpus() < 1: if context.num_gpus() < 1:
...@@ -242,7 +241,7 @@ class NcfTest(tf.test.TestCase): ...@@ -242,7 +241,7 @@ class NcfTest(tf.test.TestCase):
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1', extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1',
'--dtype', 'fp16']) '--dtype', 'fp16'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.') @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_1_gpu_dist_strat_ctl_fp16(self): def test_end_to_end_keras_1_gpu_dist_strat_ctl_fp16(self):
if context.num_gpus() < 1: if context.num_gpus() < 1:
...@@ -256,7 +255,7 @@ class NcfTest(tf.test.TestCase): ...@@ -256,7 +255,7 @@ class NcfTest(tf.test.TestCase):
'--dtype', 'fp16', '--dtype', 'fp16',
'--keras_use_ctl']) '--keras_use_ctl'])
@mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100) @unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.') @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_2_gpu_fp16(self): def test_end_to_end_keras_2_gpu_fp16(self):
if context.num_gpus() < 2: if context.num_gpus() < 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment