Commit 4b8f80c3 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Use unittest.mock as we are py3 now

PiperOrigin-RevId: 296944580
parent 02d78796
......@@ -21,7 +21,6 @@ from __future__ import print_function
import math
import unittest
import mock
import numpy as np
import tensorflow as tf
......@@ -192,34 +191,34 @@ class NcfTest(tf.test.TestCase):
_BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS)
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_keras_no_dist_strat(self):
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS +
['-distribution_strategy', 'off'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat(self):
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat_ctl(self):
flags = (self._BASE_END_TO_END_FLAGS +
......@@ -229,7 +228,7 @@ class NcfTest(tf.test.TestCase):
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=flags)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_1_gpu_dist_strat_fp16(self):
if context.num_gpus() < 1:
......@@ -242,7 +241,7 @@ class NcfTest(tf.test.TestCase):
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1',
'--dtype', 'fp16'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_1_gpu_dist_strat_ctl_fp16(self):
if context.num_gpus() < 1:
......@@ -256,7 +255,7 @@ class NcfTest(tf.test.TestCase):
'--dtype', 'fp16',
'--keras_use_ctl'])
@mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
@unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_2_gpu_fp16(self):
if context.num_gpus() < 2:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment