ncf_test.py 10.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests NCF."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import unittest
import unittest.mock

import numpy as np
import tensorflow as tf

from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
from official.recommendation import neumf_model
from official.recommendation import ncf_common
from official.recommendation import ncf_estimator_main
from official.recommendation import ncf_keras_main
from official.utils.misc import keras_utils
from official.utils.testing import integration

from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
37
38
39


# Passed as `num_training_neg` to neumf_model._get_estimator_spec_with_metrics
# when the tests below build the HR/NDCG metric ops.
NUM_TRAIN_NEG = 4
40
41
42


class NcfTest(tf.test.TestCase):
  """Tests NCF ranking metrics (HR/NDCG) and end-to-end synthetic training."""

  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    super(NcfTest, cls).setUpClass()
    # Register the NCF flags once for every test in this class.
    ncf_common.define_ncf_flags()

  def setUp(self):
    # Let tf.test.TestCase do its own per-test setup first.
    super(NcfTest, self).setUp()
    # get_hit_rate_and_ndcg overwrites these module-level constants, so keep
    # the originals for tearDown to restore.
    self.top_k_old = rconst.TOP_K
    self.num_eval_negatives_old = rconst.NUM_EVAL_NEGATIVES
    rconst.NUM_EVAL_NEGATIVES = 2

  def tearDown(self):
    rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old
    rconst.TOP_K = self.top_k_old
    super(NcfTest, self).tearDown()

  def get_hit_rate_and_ndcg(self, predicted_scores_by_user, items_by_user,
                            top_k=rconst.TOP_K, match_mlperf=False):
    """Runs the NCF eval metric ops on fixed scores and returns HR and NDCG.

    Args:
      predicted_scores_by_user: 2-D numpy array of scores, one row per user.
      items_by_user: 2-D numpy array of item ids with one row per user. The
        last column is passed to the batch assembler as the positive item and
        the remaining columns as negatives (this matches the "In top k"
        annotations in test_hit_rate_and_ndcg).
      top_k: Ranking cutoff; written to `rconst.TOP_K` before building ops.
      match_mlperf: Forwarded to the metric computation. Per
        test_hit_rate_and_ndcg, True treats duplicate items as a single item.

    Returns:
      A two-element list `[hit_rate, ndcg]` evaluated in a fresh graph.
    """
    rconst.TOP_K = top_k
    # Every column except the positive one is an evaluation negative.
    rconst.NUM_EVAL_NEGATIVES = predicted_scores_by_user.shape[1] - 1
    batch_size = items_by_user.shape[0]

    users = np.repeat(np.arange(batch_size)[:, np.newaxis],
                      rconst.NUM_EVAL_NEGATIVES + 1, axis=1)
    # Only the duplicate mask is needed; the assembled users/items are unused.
    _, _, duplicate_mask = (
        data_pipeline.BaseDataConstructor._assemble_eval_batch(
            users, items_by_user[:, -1:], items_by_user[:, :-1], batch_size))

    g = tf.Graph()
    with g.as_default():
      logits = tf.convert_to_tensor(
          predicted_scores_by_user.reshape((-1, 1)), tf.float32)
      # Each row of softmax_logits is [0, logit].
      softmax_logits = tf.concat([tf.zeros(logits.shape, dtype=logits.dtype),
                                  logits], axis=1)
      duplicate_mask = tf.convert_to_tensor(duplicate_mask, tf.float32)

      metric_ops = neumf_model._get_estimator_spec_with_metrics(
          logits=logits, softmax_logits=softmax_logits,
          duplicate_mask=duplicate_mask, num_training_neg=NUM_TRAIN_NEG,
          match_mlperf=match_mlperf).eval_metric_ops

      hr = metric_ops[rconst.HR_KEY]
      ndcg = metric_ops[rconst.NDCG_KEY]

      init = [tf.compat.v1.global_variables_initializer(),
              tf.compat.v1.local_variables_initializer()]

    with self.session(graph=g) as sess:
      sess.run(init)
      # Element [1] of each eval metric op is run. It is presumably the
      # update op of a (value, update_op) pair, as in tf.estimator
      # eval_metric_ops.
      return sess.run([hr[1], ndcg[1]])

  def _assert_hit_rate_and_ndcg(self, predictions, items, top_k, expected_hr,
                                expected_ndcg, match_mlperf=False):
    """Asserts that get_hit_rate_and_ndcg yields the expected HR and NDCG."""
    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, top_k,
                                          match_mlperf=match_mlperf)
    context_msg = "top_k={}, match_mlperf={}".format(top_k, match_mlperf)
    self.assertAlmostEqual(hr, expected_hr, msg=context_msg)
    self.assertAlmostEqual(ndcg, expected_ndcg, msg=context_msg)

  @unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
  def test_hit_rate_and_ndcg(self):
    # NDCG gain of a hit at 1-based rank r is log(2) / log(r + 1).
    gain_rank2 = math.log(2) / math.log(3)
    gain_rank3 = math.log(2) / math.log(4)
    gain_rank4 = math.log(2) / math.log(5)

    # Test with no duplicate items. MLPerf and non-MLPerf results agree here.
    predictions = np.array([
        [2., 0., 1.],  # In top 2
        [1., 0., 2.],  # In top 1
        [2., 1., 0.],  # In top 3
        [3., 4., 2.]   # In top 3
    ])
    items = np.array([
        [2, 3, 1],
        [3, 1, 2],
        [2, 1, 3],
        [1, 3, 2],
    ])
    no_duplicate_expectations = [
        # (top_k, expected hit rate, expected NDCG)
        (1, 1 / 4, 1 / 4),
        (2, 2 / 4, (1 + gain_rank2) / 4),
        (3, 4 / 4, (1 + gain_rank2 + 2 * gain_rank3) / 4),
    ]
    for match_mlperf in (False, True):
      for top_k, expected_hr, expected_ndcg in no_duplicate_expectations:
        self._assert_hit_rate_and_ndcg(predictions, items, top_k,
                                       expected_hr, expected_ndcg,
                                       match_mlperf=match_mlperf)

    # Test with duplicate items. In the MLPerf case, we treat the duplicates as
    # a single item. Otherwise, we treat the duplicates as separate items.
    predictions = np.array([
        [2., 2., 3., 1.],  # In top 4. MLPerf: In top 3
        [1., 0., 2., 3.],  # In top 1. MLPerf: In top 1
        [2., 3., 2., 0.],  # In top 4. MLPerf: In top 3
        [2., 4., 2., 3.]   # In top 2. MLPerf: In top 2
    ])
    items = np.array([
        [2, 2, 3, 1],
        [2, 3, 4, 1],
        [2, 3, 2, 1],
        [3, 2, 1, 4],
    ])
    duplicate_expectations = {
        # match_mlperf -> [(top_k, expected hit rate, expected NDCG)]
        False: [
            (1, 1 / 4, 1 / 4),
            (2, 2 / 4, (1 + gain_rank2) / 4),
            (3, 2 / 4, (1 + gain_rank2) / 4),
            (4, 4 / 4, (1 + gain_rank2 + 2 * gain_rank4) / 4),
        ],
        True: [
            (1, 1 / 4, 1 / 4),
            (2, 2 / 4, (1 + gain_rank2) / 4),
            (3, 4 / 4, (1 + gain_rank2 + 2 * gain_rank3) / 4),
            (4, 4 / 4, (1 + gain_rank2 + 2 * gain_rank3) / 4),
        ],
    }
    for match_mlperf in (False, True):
      for top_k, expected_hr, expected_ndcg in (
          duplicate_expectations[match_mlperf]):
        self._assert_hit_rate_and_ndcg(predictions, items, top_k,
                                       expected_hr, expected_ndcg,
                                       match_mlperf=match_mlperf)

  # Flags shared by every end-to-end run below.
  _BASE_END_TO_END_FLAGS = ['-batch_size', '1044', '-train_epochs', '1']

  def _skip_if_fewer_gpus_than(self, num_gpus):
    """Skips the current test unless at least `num_gpus` GPUs are visible."""
    available_gpus = context.num_gpus()
    if available_gpus < num_gpus:
      self.skipTest(
          "{} GPUs are not available for this test. {} GPUs are available".
          format(num_gpus, available_gpus))

  @unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  def test_end_to_end_estimator(self):
    integration.run_synthetic(
        ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS)

  @unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  def test_end_to_end_estimator_mlperf(self):
    integration.run_synthetic(
        ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])

  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  def test_end_to_end_keras_no_dist_strat(self):
    integration.run_synthetic(
        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS +
        ['-distribution_strategy', 'off'])

  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
  def test_end_to_end_keras_dist_strat(self):
    integration.run_synthetic(
        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])

  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
  def test_end_to_end_keras_dist_strat_ctl(self):
    flags = (self._BASE_END_TO_END_FLAGS +
             ['-num_gpus', '0'] +
             ['-keras_use_ctl', 'True'])
    integration.run_synthetic(
        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=flags)

  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
  def test_end_to_end_keras_1_gpu_dist_strat_fp16(self):
    self._skip_if_fewer_gpus_than(1)
    integration.run_synthetic(
        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1',
                                                   '--dtype', 'fp16'])

  @unittest.mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
  def test_end_to_end_keras_1_gpu_dist_strat_ctl_fp16(self):
    self._skip_if_fewer_gpus_than(1)
    integration.run_synthetic(
        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1',
                                                   '--dtype', 'fp16',
                                                   '--keras_use_ctl'])

  @unittest.mock.patch.object(rconst, 'SYNTHETIC_BATCHES_PER_EPOCH', 100)
  @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
  def test_end_to_end_keras_2_gpu_fp16(self):
    self._skip_if_fewer_gpus_than(2)
    integration.run_synthetic(
        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
        extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '2',
                                                   '--dtype', 'fp16'])
270
271
272

if __name__ == "__main__":
  # Run this module's tests with TensorFlow's test runner.
  tf.test.main()