Commit 6d7030f2 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Move NCF estimator to R1.

PiperOrigin-RevId: 303897691
parent ad34b621
......@@ -15,6 +15,7 @@ in the previous releases.
| ----- | ----------- | --------- |
| [Gradient Boosted Trees](boosted_trees) | A gradient boosted trees model to classify higgs boson process from HIGGS dataset | [Link](https://en.wikipedia.org/wiki/Gradient_boosting) |
| [MNIST](mnist) | A basic model to classify digits from the MNIST dataset | [Link](http://yann.lecun.com/exdb/mnist/) |
| [NCF](ncf) | NCF Estimator implementation | [arXiv:1708.05031](https://arxiv.org/abs/1708.05031) |
| [ResNet](resnet) | A deep residual network for image recognition | [arXiv:1512.03385](https://arxiv.org/abs/1512.03385) |
| [Transformer](transformer) | A transformer model to translate the WMT English to German dataset | [arXiv:1706.03762](https://arxiv.org/abs/1706.03762) |
| [Wide & Deep Learning](wide_deep) | A model that combines a wide linear model and deep neural network for recommender systems | [arXiv:1606.07792](https://arxiv.org/abs/1606.07792) |
......@@ -23,18 +23,15 @@ import unittest
import numpy as np
import tensorflow as tf
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
from official.recommendation import neumf_model
from official.recommendation import ncf_common
from official.recommendation import ncf_estimator_main
from official.recommendation import ncf_keras_main
from official.recommendation import neumf_model
from official.utils.misc import keras_utils
from official.utils.testing import integration
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
NUM_TRAIN_NEG = 4
......@@ -190,20 +187,6 @@ class NcfTest(tf.test.TestCase):
_BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS)
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])
@unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_keras_no_dist_strat(self):
integration.run_synthetic(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment