Unverified commit dadc4a62 authored by Shining Sun, committed by GitHub

Fix ncf test for keras (#6355)

* Fix ncf test for keras

* add a todo for batch_size and eval_batch_size for ncf keras

* lint fix

* fix typos

* Lint fix

* fix lint

* resolve pr comment

* resolve pr comment
parent ba0a6f60
@@ -640,6 +640,9 @@ class DummyConstructor(threading.Thread):
   def stop_loop(self):
     pass

+  def increment_request_epoch(self):
+    pass
+
   @staticmethod
   def make_input_fn(is_training):
     """Construct training input_fn that uses synthetic data."""
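For context, the new increment_request_epoch stub keeps the synthetic-data test double aligned with the producer interface the Keras path now exercises. A minimal sketch of the pattern, with the run method and docstrings added here as assumptions rather than taken from the diff:

import threading

class DummyConstructor(threading.Thread):
  """Test double standing in for the NCF data producer in synthetic runs."""

  def run(self):
    # Assumed: the real producer builds batches in its own thread; the stub idles.
    pass

  def stop_loop(self):
    pass

  def increment_request_epoch(self):
    # The real producer advances its epoch bookkeeping here; synthetic data
    # needs no such state, so the stub is a no-op.
    pass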
@@ -151,8 +151,21 @@ def _get_keras_model(params):
 def run_ncf(_):
   """Run NCF training and eval with Keras."""
+  # TODO(seemuch): Support different train and eval batch sizes
+  if FLAGS.eval_batch_size != FLAGS.batch_size:
+    tf.logging.warning(
+        "The Keras implementation of NCF currently does not support batch_size "
+        "!= eval_batch_size ({} vs. {}). Overriding eval_batch_size to match "
+        "batch_size".format(FLAGS.eval_batch_size, FLAGS.batch_size)
+    )
+    FLAGS.eval_batch_size = FLAGS.batch_size
+
   params = ncf_common.parse_flags(FLAGS)

+  # ncf_common rounds eval_batch_size (this is needed due to a reshape during
+  # eval). This carries over that rounding to batch_size as well.
+  params['batch_size'] = params['eval_batch_size']
+
   num_users, num_items, num_train_steps, num_eval_steps, producer = (
       ncf_common.get_inputs(params))
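The net effect of the two additions above is that train and eval end up with a single batch size after whatever rounding parse_flags applies. A minimal sketch of that behavior under assumed names (parse_flags_stub and its rounding factor are hypothetical stand-ins for ncf_common.parse_flags, whose exact rounding rule is not shown in this view):

def parse_flags_stub(batch_size, eval_batch_size, round_to=16):
  # Hypothetical stand-in: eval_batch_size is rounded so a reshape during
  # eval works; the real rounding lives inside ncf_common.parse_flags.
  rounded_eval = max(round_to, (eval_batch_size // round_to) * round_to)
  return {'batch_size': batch_size, 'eval_batch_size': rounded_eval}

params = parse_flags_stub(batch_size=1000, eval_batch_size=1000)
# Mirror the change above: carry the rounded eval batch size over so the
# Keras model trains and evaluates with one batch size.
params['batch_size'] = params['eval_batch_size']
assert params['batch_size'] == params['eval_batch_size']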
@@ -201,7 +201,6 @@ class NcfTest(tf.test.TestCase):
   @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
   def test_end_to_end_keras(self):
-    self.skipTest("TODO: fix synthetic data with keras")
     integration.run_synthetic(
         ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
         extra_flags=self._BASE_END_TO_END_FLAGS +
@@ -209,7 +208,6 @@ class NcfTest(tf.test.TestCase):
   @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
   def test_end_to_end_keras_mlperf(self):
-    self.skipTest("TODO: fix synthetic data with keras")
     integration.run_synthetic(
         ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
         extra_flags=self._BASE_END_TO_END_FLAGS +
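With the skipTest lines removed, both Keras end-to-end tests run again. A hedged example of selecting just these two tests with unittest (the module name ncf_test is taken from the test class shown above; its package path is not visible here and may differ):

import unittest

# Assumed import path; adjust to wherever ncf_test lives in the repository.
suite = unittest.defaultTestLoader.loadTestsFromNames([
    'ncf_test.NcfTest.test_end_to_end_keras',
    'ncf_test.NcfTest.test_end_to_end_keras_mlperf',
])
unittest.TextTestRunner(verbosity=2).run(suite)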