# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests NCF."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import mock
import numpy as np
import tensorflow as tf

from absl import flags
from absl.testing import flagsaver
from official.recommendation import constants as rconst
from official.recommendation import data_preprocessing
from official.recommendation import neumf_model
from official.recommendation import ncf_main
from official.recommendation import stat_utils


# Number of negatives generated per positive example during training; passed
# through to compute_eval_loss_and_metrics below.
NUM_TRAIN_NEG = 4


class NcfTest(tf.test.TestCase):

  # Register the NCF command-line flags once for all tests in this class, so
  # the flagsaver decorators on the end-to-end tests have flags to snapshot.
  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    super(NcfTest, cls).setUpClass()
    ncf_main.define_ncf_flags()

  # The tests mutate module-level constants in rconst, so setUp saves the
  # originals and tearDown restores them.
  def setUp(self):
    self.top_k_old = rconst.TOP_K
    self.num_eval_negatives_old = rconst.NUM_EVAL_NEGATIVES
    rconst.NUM_EVAL_NEGATIVES = 2

  def tearDown(self):
    rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old
    rconst.TOP_K = self.top_k_old

  def get_hit_rate_and_ndcg(self, predicted_scores_by_user, items_by_user,
                            top_k=rconst.TOP_K, match_mlperf=False):
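    """Computes HR@top_k and NDCG@top_k via the model's eval metric ops.

    Each row of `predicted_scores_by_user` scores one user's candidates: the
    true item's score is in column 0 and the remaining columns hold scores
    for sampled negatives, with `items_by_user` giving the matching item ids
    so duplicates can be detected.
    """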
    rconst.TOP_K = top_k
    rconst.NUM_EVAL_NEGATIVES = predicted_scores_by_user.shape[1] - 1

    g = tf.Graph()
    with g.as_default():
      logits = tf.convert_to_tensor(
          predicted_scores_by_user.reshape((-1, 1)), tf.float32)
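      # Build two-class logits [0, logit] per example, so that a softmax over
      # the pair is equivalent to a sigmoid on the raw logit.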
      softmax_logits = tf.concat([tf.zeros(logits.shape, dtype=logits.dtype),
                                  logits], axis=1)
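      # Flag repeated item ids within each user's row so the metric
      # computation can discount duplicate negatives.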
      duplicate_mask = tf.convert_to_tensor(
          stat_utils.mask_duplicates(items_by_user, axis=1), tf.float32)

      metric_ops = neumf_model.compute_eval_loss_and_metrics(
          logits=logits, softmax_logits=softmax_logits,
          duplicate_mask=duplicate_mask, num_training_neg=NUM_TRAIN_NEG,
          match_mlperf=match_mlperf).eval_metric_ops

      hr = metric_ops[rconst.HR_KEY]
      ndcg = metric_ops[rconst.NDCG_KEY]

      init = [tf.global_variables_initializer(),
              tf.local_variables_initializer()]

    with self.test_session(graph=g) as sess:
      sess.run(init)
      return sess.run([hr[1], ndcg[1]])

  def test_hit_rate_and_ndcg(self):
    # Test with no duplicate items
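    # Column 0 of each row holds the true item's score; the per-row comments
    # give its resulting rank. A hit at rank r contributes
    # log(2) / log(r + 1) to the NDCG sums asserted below.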
    predictions = np.array([
        [1., 2., 0.],  # In top 2
        [2., 1., 0.],  # In top 1
        [0., 2., 1.],  # In top 3
        [2., 3., 4.]   # In top 3
    ])
    items = np.array([
        [1, 2, 3],
        [2, 3, 1],
        [3, 2, 1],
        [2, 1, 3],
    ])

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1)
    self.assertAlmostEqual(hr, 1 / 4)
    self.assertAlmostEqual(ndcg, 1 / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2)
    self.assertAlmostEqual(hr, 2 / 4)
    self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3)
    self.assertAlmostEqual(hr, 4 / 4)
    self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
                                  2 * math.log(2) / math.log(4)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1,
                                          match_mlperf=True)
    self.assertAlmostEqual(hr, 1 / 4)
    self.assertAlmostEqual(ndcg, 1 / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2,
                                          match_mlperf=True)
    self.assertAlmostEqual(hr, 2 / 4)
    self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3,
                                          match_mlperf=True)
    self.assertAlmostEqual(hr, 4 / 4)
    self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
                                  2 * math.log(2) / math.log(4)) / 4)

    # Test with duplicate items. In the MLPerf case, we treat the duplicates as
    # a single item. Otherwise, we treat the duplicates as separate items.
    predictions = np.array([
        [1., 2., 2., 3.],  # In top 4. MLPerf: In top 3
        [3., 1., 0., 2.],  # In top 1. MLPerf: In top 1
        [0., 2., 3., 2.],  # In top 4. MLPerf: In top 3
        [3., 2., 4., 2.]   # In top 2. MLPerf: In top 2
    ])
    items = np.array([
        [1, 2, 2, 3],
        [1, 2, 3, 4],
        [1, 2, 3, 2],
        [4, 3, 2, 1],
    ])
    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1)
    self.assertAlmostEqual(hr, 1 / 4)
    self.assertAlmostEqual(ndcg, 1 / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2)
    self.assertAlmostEqual(hr, 2 / 4)
    self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3)
    self.assertAlmostEqual(hr, 2 / 4)
    self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 4)
    self.assertAlmostEqual(hr, 4 / 4)
    self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
                                  2 * math.log(2) / math.log(5)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1,
                                          match_mlperf=True)
    self.assertAlmostEqual(hr, 1 / 4)
    self.assertAlmostEqual(ndcg, 1 / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2,
                                          match_mlperf=True)
    self.assertAlmostEqual(hr, 2 / 4)
    self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3,
                                          match_mlperf=True)
    self.assertAlmostEqual(hr, 4 / 4)
    self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
                                  2 * math.log(2) / math.log(4)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 4,
                                          match_mlperf=True)
    self.assertAlmostEqual(hr, 4 / 4)
    self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
                                  2 * math.log(2) / math.log(4)) / 4)

    # Test with duplicate items, where the predictions for the same item can
    # differ. In the MLPerf case, we should take the first prediction.
    predictions = np.array([
        [3., 2., 4., 4.],  # In top 3. MLPerf: In top 2
        [3., 4., 2., 4.],  # In top 3. MLPerf: In top 3
        [2., 3., 4., 1.],  # In top 3. MLPerf: In top 2
        [4., 3., 5., 2.]   # In top 2. MLPerf: In top 1
    ])
    items = np.array([
        [1, 2, 2, 3],
        [4, 3, 3, 2],
        [2, 1, 1, 1],
        [4, 2, 2, 1],
    ])
    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1)
    self.assertAlmostEqual(hr, 0 / 4)
    self.assertAlmostEqual(ndcg, 0 / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2)
    self.assertAlmostEqual(hr, 1 / 4)
    self.assertAlmostEqual(ndcg, (math.log(2) / math.log(3)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3)
    self.assertAlmostEqual(hr, 4 / 4)
    self.assertAlmostEqual(ndcg, (math.log(2) / math.log(3) +
                                  3 * math.log(2) / math.log(4)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 4)
    self.assertAlmostEqual(hr, 4 / 4)
    self.assertAlmostEqual(ndcg, (math.log(2) / math.log(3) +
                                  3 * math.log(2) / math.log(4)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1,
                                          match_mlperf=True)
    self.assertAlmostEqual(hr, 1 / 4)
    self.assertAlmostEqual(ndcg, 1 / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2,
                                          match_mlperf=True)
    self.assertAlmostEqual(hr, 3 / 4)
    self.assertAlmostEqual(ndcg, (1 + 2 * math.log(2) / math.log(3)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3,
                                          match_mlperf=True)
    self.assertAlmostEqual(hr, 4 / 4)
    self.assertAlmostEqual(ndcg, (1 + 2 * math.log(2) / math.log(3) +
                                  math.log(2) / math.log(4)) / 4)

    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 4,
                                          match_mlperf=True)
    self.assertAlmostEqual(hr, 4 / 4)
    self.assertAlmostEqual(ndcg, (1 + 2 * math.log(2) / math.log(3) +
                                  math.log(2) / math.log(4)) / 4)

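  # Shared flags for the end-to-end smoke tests: one epoch of synthetic data,
  # with SYNTHETIC_BATCHES_PER_EPOCH patched below so each run stays small.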
  _BASE_END_TO_END_FLAGS = {
      "batch_size": 1024,
      "train_epochs": 1,
      "use_synthetic_data": True
  }

  @flagsaver.flagsaver(**_BASE_END_TO_END_FLAGS)
  @mock.patch.object(data_preprocessing, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  def test_end_to_end(self):
    ncf_main.main(None)

  @flagsaver.flagsaver(ml_perf=True, **_BASE_END_TO_END_FLAGS)
  @mock.patch.object(data_preprocessing, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  def test_end_to_end_mlperf(self):
    ncf_main.main(None)

  @flagsaver.flagsaver(use_estimator=False, **_BASE_END_TO_END_FLAGS)
  @mock.patch.object(data_preprocessing, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  def test_end_to_end_no_estimator(self):
    ncf_main.main(None)
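    # Run again with MLPerf-compatible evaluation enabled.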
    flags.FLAGS.ml_perf = True
    ncf_main.main(None)

  @flagsaver.flagsaver(use_estimator=False, **_BASE_END_TO_END_FLAGS)
  @mock.patch.object(data_preprocessing, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
  def test_end_to_end_while_loop(self):
    # We cannot set use_while_loop = True in the flagsaver constructor, because
    # if the flagsaver sets it to True before setting use_estimator to False,
    # the flag validator will throw an error.
    flags.FLAGS.use_while_loop = True
    ncf_main.main(None)
    flags.FLAGS.ml_perf = True
    ncf_main.main(None)


if __name__ == "__main__":
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.test.main()