Commit c494582f authored by Taylor Robie, committed by Reed

Move evaluation to .evaluate() (#5413)

* move evaluation from numpy to tensorflow

fix syntax error

don't use sigmoid to convert logits. there is too much precision loss.

WIP: add logit metrics

continue refactor of NCF evaluation

fix syntax error

fix bugs in eval loss calculation

fix eval loss reweighting

remove numpy based metric calculations

fix logging hooks

fix sigmoid to softmax bug

fix comment

catch rare PIPE error and address some PR comments

* fix metric test and address PR comments

* delint and fix python2

* fix test and address PR comments

* extend eval to TPUs
parent f3be93a7
@@ -52,6 +52,12 @@ MIN_NUM_RATINGS = 20
 # when performing evaluation.
 NUM_EVAL_NEGATIVES = 999
+
+# keys for evaluation metrics
+TOP_K = 10  # Top-k list for evaluation
+HR_KEY = "HR"
+NDCG_KEY = "NDCG"
+DUPLICATE_MASK = "duplicate_mask"
 # ==============================================================================
 # == Subprocess Data Generation ================================================
 # ==============================================================================
...
@@ -124,7 +124,7 @@ def _process_shard(args):
   return users_out, items_out, labels_out

-def _construct_record(users, items, labels=None):
+def _construct_record(users, items, labels=None, dupe_mask=None):
   """Convert NumPy arrays into a TFRecords entry."""
   feature_dict = {
       movielens.USER_COLUMN: tf.train.Feature(
@@ -136,6 +136,10 @@ def _construct_record(users, items, labels=None):
     feature_dict["labels"] = tf.train.Feature(
         bytes_list=tf.train.BytesList(value=[memoryview(labels).tobytes()]))
+  if dupe_mask is not None:
+    feature_dict[rconst.DUPLICATE_MASK] = tf.train.Feature(
+        bytes_list=tf.train.BytesList(value=[memoryview(dupe_mask).tobytes()]))
   return tf.train.Example(
       features=tf.train.Features(feature=feature_dict)).SerializeToString()
@@ -305,6 +309,9 @@ def _construct_training_records(
 def _construct_eval_record(cache_paths, eval_batch_size):
   """Convert Eval data to a single TFRecords file."""
+  # Later logic assumes that all items for a given user are in the same batch.
+  assert not eval_batch_size % (rconst.NUM_EVAL_NEGATIVES + 1)
+
   log_msg("Beginning construction of eval TFRecords file.")
   raw_fpath = cache_paths.eval_raw_file
   intermediate_fpath = cache_paths.eval_record_template_temp
@@ -332,9 +339,16 @@ def _construct_eval_record(cache_paths, eval_batch_size):
   num_batches = users.shape[0]
   with tf.python_io.TFRecordWriter(intermediate_fpath) as writer:
     for i in range(num_batches):
+      batch_users = users[i, :]
+      batch_items = items[i, :]
+      dupe_mask = stat_utils.mask_duplicates(
+          batch_items.reshape(-1, rconst.NUM_EVAL_NEGATIVES + 1),
+          axis=1).flatten().astype(np.int8)
       batch_bytes = _construct_record(
-          users=users[i, :],
-          items=items[i, :]
+          users=batch_users,
+          items=batch_items,
+          dupe_mask=dupe_mask
       )
       writer.write(batch_bytes)
   tf.gfile.Rename(intermediate_fpath, dest_fpath)
...
@@ -27,6 +27,7 @@ import json
 import os
 import pickle
 import signal
+import socket
 import subprocess
 import time
 import timeit
@@ -399,10 +400,14 @@ def _shutdown(proc):
   """Convenience function to cleanly shut down async generation process."""
   tf.logging.info("Shutting down train data creation subprocess.")
-  proc.send_signal(signal.SIGINT)
-  time.sleep(1)
-  if proc.returncode is not None:
-    return  # SIGINT was handled successfully within 1 sec
+  try:
+    proc.send_signal(signal.SIGINT)
+    time.sleep(1)
+    if proc.returncode is not None:
+      return  # SIGINT was handled successfully within 1 sec
+  except socket.error:
+    pass

   # Otherwise another second of grace period and then forcibly kill the process.
   time.sleep(1)
@@ -493,6 +498,8 @@ def make_deserialize(params, batch_size, training=False):
   }
   if training:
     feature_map["labels"] = tf.FixedLenFeature([], dtype=tf.string)
+  else:
+    feature_map[rconst.DUPLICATE_MASK] = tf.FixedLenFeature([], dtype=tf.string)

   def deserialize(examples_serialized):
     """Called by Dataset.map() to convert batches of records to tensors."""
@@ -506,13 +513,17 @@ def make_deserialize(params, batch_size, training=False):
     items = tf.cast(items, tf.int32)  # TPU doesn't allow uint16 infeed.

     if not training:
+      dupe_mask = tf.reshape(tf.cast(tf.decode_raw(
+          features[rconst.DUPLICATE_MASK], tf.int8), tf.bool), (batch_size,))
       return {
           movielens.USER_COLUMN: users,
           movielens.ITEM_COLUMN: items,
+          rconst.DUPLICATE_MASK: dupe_mask,
       }

     labels = tf.reshape(tf.cast(tf.decode_raw(
         features["labels"], tf.int8), tf.bool), (batch_size,))
     return {
         movielens.USER_COLUMN: users,
         movielens.ITEM_COLUMN: items,
...
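The duplicate mask travels through the TFRecord as raw int8 bytes and is decoded back to booleans inside deserialize(). A minimal NumPy-only sketch of that round trip (illustrative, not part of the commit):

import numpy as np

mask = np.array([0, 0, 1, 0], dtype=np.int8)
raw = memoryview(mask).tobytes()   # what _construct_record stores in the BytesList
recovered = np.frombuffer(raw, dtype=np.int8).astype(bool)
# tf.decode_raw(features[rconst.DUPLICATE_MASK], tf.int8) followed by a cast
# to tf.bool performs the same inverse operation in the input pipeline.
assert list(recovered) == [False, False, True, False]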
@@ -36,6 +36,7 @@ NUM_USERS = 1000
 NUM_ITEMS = 2000
 NUM_PTS = 50000
 BATCH_SIZE = 2048
+EVAL_BATCH_SIZE = 4000
 NUM_NEG = 4
@@ -112,8 +113,8 @@ class BaseTest(tf.test.TestCase):
   def test_end_to_end(self):
     ncf_dataset, _ = data_preprocessing.instantiate_pipeline(
         dataset=DATASET, data_dir=self.temp_data_dir,
-        batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE, num_data_readers=2,
-        num_neg=NUM_NEG)
+        batch_size=BATCH_SIZE, eval_batch_size=EVAL_BATCH_SIZE,
+        num_data_readers=2, num_neg=NUM_NEG)
     g = tf.Graph()
     with g.as_default():
...
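The new EVAL_BATCH_SIZE of 4000 satisfies the divisibility assumption introduced in _construct_eval_record above, since each user contributes exactly 1 + NUM_EVAL_NEGATIVES examples. A quick check (illustrative, not part of the commit):

NUM_EVAL_NEGATIVES = 999
EVAL_BATCH_SIZE = 4000
assert EVAL_BATCH_SIZE % (NUM_EVAL_NEGATIVES + 1) == 0  # 4 users per eval batch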
@@ -23,7 +23,6 @@ from __future__ import division
 from __future__ import print_function

 import contextlib
-import gc
 import heapq
 import math
 import multiprocessing
@@ -48,183 +47,6 @@ from official.utils.logs import logger
 from official.utils.misc import distribution_utils
 from official.utils.misc import model_helpers
-_TOP_K = 10  # Top-k list for evaluation
-
-# keys for evaluation metrics
-_HR_KEY = "HR"
-_NDCG_KEY = "NDCG"
-
-
-def get_hit_rate_and_ndcg(predicted_scores_by_user, items_by_user, top_k=_TOP_K,
-                          match_mlperf=False):
-  """Returns the hit rate and the normalized DCG for evaluation.
-
-  `predicted_scores_by_user` and `items_by_user` are parallel NumPy arrays
-  with shape (num_users, num_items) such that `predicted_scores_by_user[i, j]`
-  is the predicted score that user `i` would rate item `items_by_user[i][j]`.
-  `items_by_user[i, 0]` is the item that user `i` interacted with, while
-  `items_by_user[i, 1:]` are items that user `i` did not interact with. The
-  goal of the NCF model is to give a high score for
-  `predicted_scores_by_user[i, 0]` compared to
-  `predicted_scores_by_user[i, 1:]`, and the returned HR and NDCG will be
-  higher the more successful the model is at this goal.
-
-  If `match_mlperf` is True, then the HR and NDCG computations are done in a
-  slightly unusual way to match the MLPerf reference implementation.
-  Specifically, if `items_by_user[i, :]` contains duplicate items, each
-  duplicate is treated as if the item only appeared once. Effectively, for
-  duplicate items in a row, the predicted score for all but one of the items
-  will be set to -infinity.
-
-  For example, suppose we have the following inputs:
-  predicted_scores_by_user: [[ 2, 3, 3],
-                             [ 5, 4, 4]]
-  items_by_user:            [[10, 20, 20],
-                             [30, 40, 40]]
-  top_k: 2
-
-  Then with match_mlperf=True, the HR would be 2/2 = 1.0. With
-  match_mlperf=False, the HR would be 1/2 = 0.5. This is because each user has
-  predicted scores for only 2 unique items: 10 and 20 for the first user, and
-  30 and 40 for the second. Therefore, with match_mlperf=True, it's guaranteed
-  the first item's score is in the top 2. With match_mlperf=False, this
-  function would determine that the first user's first item is not in the top
-  2, because item 20 has a higher score, and item 20 occurs twice.
-
-  Args:
-    predicted_scores_by_user: 2D NumPy array of the predicted scores.
-      `predicted_scores_by_user[i, j]` is the predicted score that user `i`
-      would rate item `items_by_user[i][j]`.
-    items_by_user: 2D NumPy array of the item IDs. For user `i`,
-      `items_by_user[i][0]` is the item that user `i` interacted with, while
-      `items_by_user[i][1:]` are items that user `i` did not interact with.
-    top_k: Only consider the highest rated `top_k` items per user. The HR and
-      NDCG for that user will only be nonzero if the predicted score for that
-      user's first item is in the `top_k` top scores.
-    match_mlperf: If True, compute HR and NDCG slightly differently to match
-      the MLPerf reference implementation.
-
-  Returns:
-    (hr, ndcg) tuple of floats, averaged across all users.
-  """
-  num_users = predicted_scores_by_user.shape[0]
-  zero_indices = np.zeros((num_users, 1), dtype=np.int32)
-
-  if match_mlperf:
-    predicted_scores_by_user = predicted_scores_by_user.copy()
-    items_by_user = items_by_user.copy()
-
-    # For each user, sort the items and predictions by increasing item number.
-    # We use mergesort since it's the only stable sort, which we need to be
-    # equivalent to the MLPerf reference implementation.
-    sorted_items_indices = items_by_user.argsort(kind="mergesort")
-    sorted_items = items_by_user[
-        np.arange(num_users)[:, np.newaxis], sorted_items_indices]
-    sorted_predictions = predicted_scores_by_user[
-        np.arange(num_users)[:, np.newaxis], sorted_items_indices]
-
-    # For items that occur more than once in a user's row, set the predicted
-    # score of the subsequent occurrences to -infinity, which effectively
-    # removes them from the array.
-    diffs = sorted_items[:, :-1] - sorted_items[:, 1:]
-    diffs = np.concatenate(
-        [np.ones((diffs.shape[0], 1), dtype=diffs.dtype), diffs], axis=1)
-    predicted_scores_by_user = np.where(diffs, sorted_predictions, -np.inf)
-
-    # After this block, `zero_indices` will be a (num_users, 1) shaped array
-    # indicating, for each user, the index of the item of value 0 in
-    # `sorted_items_indices`. This item is the one we want to check if it is
-    # in the top_k items.
-    zero_indices = np.array(np.where(sorted_items_indices == 0))
-    assert np.array_equal(zero_indices[0, :], np.arange(num_users))
-    zero_indices = zero_indices[1, :, np.newaxis]
-
-  # NumPy has an np.argpartition() method, however log(1000) is so small that
-  # sorting the whole array is simpler and fast enough.
-  top_indices = np.argsort(predicted_scores_by_user, axis=1)[:, -top_k:]
-  top_indices = np.flip(top_indices, axis=1)
-
-  # Both HR and NDCG vectorized computation takes advantage of the fact that
-  # if the positive example for a user is not in the top k, that index does
-  # not appear. That is to say: hit_ind.shape[0] <= num_users
-  hit_ind = np.argwhere(np.equal(top_indices, zero_indices))
-  hr = hit_ind.shape[0] / num_users
-  ndcg = np.sum(np.log(2) / np.log(hit_ind[:, 1] + 2)) / num_users
-  return hr, ndcg
-
-
-def evaluate_model(estimator, ncf_dataset, pred_input_fn):
-  # type: (tf.estimator.Estimator, prepare.NCFDataset, typing.Callable) -> dict
-  """Model evaluation with HR and NDCG metrics.
-
-  The evaluation protocol is to rank the test interacted item (truth items)
-  among the randomly chosen 999 items that are not interacted by the user.
-  The performance of the ranked list is judged by Hit Ratio (HR) and
-  Normalized Discounted Cumulative Gain (NDCG).
-
-  For evaluation, the ranked list is truncated at 10 for both metrics. As
-  such, the HR intuitively measures whether the test item is present on the
-  top-10 list, and the NDCG accounts for the position of the hit by assigning
-  higher scores to hits at top ranks. Both metrics are calculated for each
-  test user, and the average scores are reported.
-
-  Args:
-    estimator: The Estimator.
-    ncf_dataset: An NCFDataset object, which contains the information about
-      test/eval dataset, such as:
-        num_users: How many unique users are in the eval set.
-        test_data: The points which are used for consistent evaluation. These
-          are already included in the pred_input_fn.
-    pred_input_fn: The input function for the test data.
-
-  Returns:
-    eval_results: A dict of evaluation results for benchmark logging.
-      eval_results = {
-          _HR_KEY: hr,
-          _NDCG_KEY: ndcg,
-          tf.GraphKeys.GLOBAL_STEP: global_step
-      }
-      where hr is a float indicating the average HR score across all users,
-      ndcg is a float representing the average NDCG score across all users,
-      and global_step is the global step.
-  """
-  tf.logging.info("Computing predictions for eval set...")
-
-  # Get predictions
-  predictions = estimator.predict(input_fn=pred_input_fn,
-                                  yield_single_examples=False)
-  predictions = list(predictions)
-
-  prediction_batches = [p[movielens.RATING_COLUMN] for p in predictions]
-  item_batches = [p[movielens.ITEM_COLUMN] for p in predictions]
-
-  # Reshape the predicted scores and items. Each user takes one row.
-  prediction_with_padding = np.concatenate(prediction_batches, axis=0)
-  predicted_scores_by_user = prediction_with_padding[
-      :ncf_dataset.num_users * (1 + rconst.NUM_EVAL_NEGATIVES)]\
-      .reshape(ncf_dataset.num_users, -1)
-  item_with_padding = np.concatenate(item_batches, axis=0)
-  items_by_user = item_with_padding[
-      :ncf_dataset.num_users * (1 + rconst.NUM_EVAL_NEGATIVES)]\
-      .reshape(ncf_dataset.num_users, -1)
-
-  tf.logging.info("Computing metrics...")
-
-  hr, ndcg = get_hit_rate_and_ndcg(predicted_scores_by_user, items_by_user,
-                                   match_mlperf=FLAGS.ml_perf)
-
-  global_step = estimator.get_variable_value(tf.GraphKeys.GLOBAL_STEP)
-  eval_results = {
-      _HR_KEY: hr,
-      _NDCG_KEY: ndcg,
-      tf.GraphKeys.GLOBAL_STEP: global_step
-  }
-
-  return eval_results
 def construct_estimator(num_gpus, model_dir, params, batch_size,
                         eval_batch_size):
@@ -274,7 +96,7 @@ def construct_estimator(num_gpus, model_dir, params, batch_size,
         model_fn=neumf_model.neumf_model_fn,
         use_tpu=False,
         train_batch_size=1,
-        predict_batch_size=eval_batch_size,
+        eval_batch_size=eval_batch_size,
         params=tpu_params,
         config=run_config)
@@ -305,7 +127,15 @@ def run_ncf(_):
   num_gpus = flags_core.get_num_gpus(FLAGS)
   batch_size = distribution_utils.per_device_batch_size(
       int(FLAGS.batch_size), num_gpus)
+
   eval_batch_size = int(FLAGS.eval_batch_size or FLAGS.batch_size)
+  eval_per_user = rconst.NUM_EVAL_NEGATIVES + 1
+  if eval_batch_size % eval_per_user:
+    eval_batch_size = eval_batch_size // eval_per_user * eval_per_user
+    tf.logging.warning(
+        "eval examples per user does not evenly divide eval_batch_size. "
+        "Overriding to {}".format(eval_batch_size))
+
   ncf_dataset, cleanup_fn = data_preprocessing.instantiate_pipeline(
       dataset=FLAGS.dataset, data_dir=FLAGS.data_dir,
       batch_size=batch_size,
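The rounding above keeps each user's eval_per_user examples inside a single batch, matching the assert added in _construct_eval_record. A quick arithmetic sketch (illustrative, not part of the commit), using the run script's old and new eval batch sizes:

eval_per_user = 999 + 1                             # rconst.NUM_EVAL_NEGATIVES + 1
requested = 65536                                   # old script value, not a multiple of 1000
print(requested // eval_per_user * eval_per_user)   # 65000: rounded down with a warning
print(100000 % eval_per_user)                       # 0: the new script value needs no rounding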
@@ -329,6 +159,7 @@ def run_ncf(_):
       "model_layers": [int(layer) for layer in FLAGS.layers],
       "mf_regularization": FLAGS.mf_regularization,
       "mlp_reg_layers": [float(reg) for reg in FLAGS.mlp_regularization],
+      "num_neg": FLAGS.num_neg,
       "use_tpu": FLAGS.tpu is not None,
       "tpu": FLAGS.tpu,
       "tpu_zone": FLAGS.tpu_zone,
@@ -336,13 +167,15 @@ def run_ncf(_):
       "beta1": FLAGS.beta1,
       "beta2": FLAGS.beta2,
       "epsilon": FLAGS.epsilon,
+      "match_mlperf": FLAGS.ml_perf,
   }, batch_size=flags.FLAGS.batch_size, eval_batch_size=eval_batch_size)

   # Create hooks that log information about the training and metric values
   train_hooks = hooks_helper.get_train_hooks(
       FLAGS.hooks,
       model_dir=FLAGS.model_dir,
-      batch_size=FLAGS.batch_size  # for ExamplesPerSecondHook
+      batch_size=FLAGS.batch_size,  # for ExamplesPerSecondHook
+      tensors_to_log={"cross_entropy": "cross_entropy"}
   )
   run_params = {
       "batch_size": FLAGS.batch_size,
@@ -367,7 +200,6 @@ def run_ncf(_):
     tf.logging.info("Starting a training cycle: {}/{}".format(
         cycle_index + 1, total_training_cycle))

     # Train the model
     train_input_fn, train_record_dir, batch_count = \
         data_preprocessing.make_train_input_fn(ncf_dataset=ncf_dataset)
@@ -381,23 +213,19 @@ def run_ncf(_):
         steps=batch_count)
     tf.gfile.DeleteRecursively(train_record_dir)

-    # Evaluate the model
-    eval_results = evaluate_model(
-        eval_estimator, ncf_dataset, pred_input_fn)
+    tf.logging.info("Beginning evaluation.")
+    eval_results = eval_estimator.evaluate(pred_input_fn)
+    tf.logging.info("Evaluation complete.")

     # Benchmark the evaluation results
     benchmark_logger.log_evaluation_result(eval_results)

     # Log the HR and NDCG results.
-    hr = eval_results[_HR_KEY]
-    ndcg = eval_results[_NDCG_KEY]
+    hr = eval_results[rconst.HR_KEY]
+    ndcg = eval_results[rconst.NDCG_KEY]
     tf.logging.info(
         "Iteration {}: HR = {:.4f}, NDCG = {:.4f}".format(
             cycle_index + 1, hr, ndcg))

-    # Some of the NumPy vector math can be quite large and likes to stay in
-    # memory for a while.
-    gc.collect()
-
     # If some evaluation threshold is met
     if model_helpers.past_stop_threshold(FLAGS.hr_threshold, hr):
       break
@@ -534,6 +362,11 @@ def define_ncf_flags():
           "training. However it is useful to confirm that a random seed "
           "does indeed make the data pipeline deterministic."))

+  @flags.validator("eval_batch_size", "eval_batch_size must be at least {}"
+                   .format(rconst.NUM_EVAL_NEGATIVES + 1))
+  def eval_size_check(eval_batch_size):
+    return int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES
+

 if __name__ == "__main__":
   tf.logging.set_verbosity(tf.logging.INFO)
...
@@ -23,10 +23,56 @@ import math
 import numpy as np
 import tensorflow as tf

+from official.recommendation import constants as rconst
+from official.recommendation import neumf_model
 from official.recommendation import ncf_main
+from official.recommendation import stat_utils
+
+NUM_TRAIN_NEG = 4


 class NcfTest(tf.test.TestCase):
+  def setUp(self):
+    self.top_k_old = rconst.TOP_K
+    self.num_eval_negatives_old = rconst.NUM_EVAL_NEGATIVES
+    rconst.NUM_EVAL_NEGATIVES = 2
+
+  def tearDown(self):
+    rconst.NUM_EVAL_NEGATIVES = self.num_eval_negatives_old
+    rconst.TOP_K = self.top_k_old
+
+  def get_hit_rate_and_ndcg(self, predicted_scores_by_user, items_by_user,
+                            top_k=rconst.TOP_K, match_mlperf=False):
+    rconst.TOP_K = top_k
+    rconst.NUM_EVAL_NEGATIVES = predicted_scores_by_user.shape[1] - 1
+
+    g = tf.Graph()
+    with g.as_default():
+      logits = tf.convert_to_tensor(
+          predicted_scores_by_user.reshape((-1, 1)), tf.float32)
+      softmax_logits = tf.concat([tf.zeros(logits.shape, dtype=logits.dtype),
+                                  logits], axis=1)
+      duplicate_mask = tf.convert_to_tensor(
+          stat_utils.mask_duplicates(items_by_user, axis=1), tf.float32)
+
+      metric_ops = neumf_model.compute_eval_loss_and_metrics(
+          logits=logits, softmax_logits=softmax_logits,
+          duplicate_mask=duplicate_mask, num_training_neg=NUM_TRAIN_NEG,
+          match_mlperf=match_mlperf).eval_metric_ops
+
+      hr = metric_ops[rconst.HR_KEY]
+      ndcg = metric_ops[rconst.NDCG_KEY]
+
+      init = [tf.global_variables_initializer(),
+              tf.local_variables_initializer()]
+
+    with self.test_session(graph=g) as sess:
+      sess.run(init)
+      return sess.run([hr[1], ndcg[1]])
   def test_hit_rate_and_ndcg(self):
     # Test with no duplicate items
     predictions = np.array([
@@ -41,27 +87,32 @@ class NcfTest(tf.test.TestCase):
         [3, 2, 1],
         [2, 1, 3],
     ])
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 1)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1)
     self.assertAlmostEqual(hr, 1 / 4)
     self.assertAlmostEqual(ndcg, 1 / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 2)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2)
     self.assertAlmostEqual(hr, 2 / 4)
     self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 3)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3)
     self.assertAlmostEqual(hr, 4 / 4)
     self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
                                   2 * math.log(2) / math.log(4)) / 4)

-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 1,
-                                              match_mlperf=True)
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1,
+                                          match_mlperf=True)
     self.assertAlmostEqual(hr, 1 / 4)
     self.assertAlmostEqual(ndcg, 1 / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 2,
-                                              match_mlperf=True)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2,
+                                          match_mlperf=True)
     self.assertAlmostEqual(hr, 2 / 4)
     self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 3,
-                                              match_mlperf=True)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3,
+                                          match_mlperf=True)
     self.assertAlmostEqual(hr, 4 / 4)
     self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
                                   2 * math.log(2) / math.log(4)) / 4)
@@ -80,35 +131,41 @@ class NcfTest(tf.test.TestCase):
         [1, 2, 3, 2],
         [4, 3, 2, 1],
     ])
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 1)
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1)
     self.assertAlmostEqual(hr, 1 / 4)
     self.assertAlmostEqual(ndcg, 1 / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 2)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2)
     self.assertAlmostEqual(hr, 2 / 4)
     self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 3)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3)
     self.assertAlmostEqual(hr, 2 / 4)
     self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 4)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 4)
     self.assertAlmostEqual(hr, 4 / 4)
     self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
                                   2 * math.log(2) / math.log(5)) / 4)

-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 1,
-                                              match_mlperf=True)
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1,
+                                          match_mlperf=True)
     self.assertAlmostEqual(hr, 1 / 4)
     self.assertAlmostEqual(ndcg, 1 / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 2,
-                                              match_mlperf=True)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2,
+                                          match_mlperf=True)
     self.assertAlmostEqual(hr, 2 / 4)
     self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3)) / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 3,
-                                              match_mlperf=True)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3,
+                                          match_mlperf=True)
     self.assertAlmostEqual(hr, 4 / 4)
     self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
                                   2 * math.log(2) / math.log(4)) / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 4,
-                                              match_mlperf=True)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 4,
+                                          match_mlperf=True)
     self.assertAlmostEqual(hr, 4 / 4)
     self.assertAlmostEqual(ndcg, (1 + math.log(2) / math.log(3) +
                                   2 * math.log(2) / math.log(4)) / 4)
@@ -127,36 +184,42 @@ class NcfTest(tf.test.TestCase):
         [2, 1, 1, 1],
         [4, 2, 2, 1],
     ])
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 1)
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1)
     self.assertAlmostEqual(hr, 0 / 4)
     self.assertAlmostEqual(ndcg, 0 / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 2)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2)
     self.assertAlmostEqual(hr, 1 / 4)
     self.assertAlmostEqual(ndcg, (math.log(2) / math.log(3)) / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 3)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3)
     self.assertAlmostEqual(hr, 4 / 4)
     self.assertAlmostEqual(ndcg, (math.log(2) / math.log(3) +
                                   3 * math.log(2) / math.log(4)) / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 4)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 4)
     self.assertAlmostEqual(hr, 4 / 4)
     self.assertAlmostEqual(ndcg, (math.log(2) / math.log(3) +
                                   3 * math.log(2) / math.log(4)) / 4)

-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 1,
-                                              match_mlperf=True)
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 1,
+                                          match_mlperf=True)
     self.assertAlmostEqual(hr, 1 / 4)
     self.assertAlmostEqual(ndcg, 1 / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 2,
-                                              match_mlperf=True)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 2,
+                                          match_mlperf=True)
     self.assertAlmostEqual(hr, 3 / 4)
     self.assertAlmostEqual(ndcg, (1 + 2 * math.log(2) / math.log(3)) / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 3,
-                                              match_mlperf=True)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 3,
+                                          match_mlperf=True)
     self.assertAlmostEqual(hr, 4 / 4)
     self.assertAlmostEqual(ndcg, (1 + 2 * math.log(2) / math.log(3) +
                                   math.log(2) / math.log(4)) / 4)
-    hr, ndcg = ncf_main.get_hit_rate_and_ndcg(predictions, items, 4,
-                                              match_mlperf=True)
+
+    hr, ndcg = self.get_hit_rate_and_ndcg(predictions, items, 4,
+                                          match_mlperf=True)
     self.assertAlmostEqual(hr, 4 / 4)
     self.assertAlmostEqual(ndcg, (1 + 2 * math.log(2) / math.log(3) +
                                   math.log(2) / math.log(4)) / 4)
...
@@ -40,6 +40,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf

 from official.datasets import movielens  # pylint: disable=g-bad-import-order
+from official.recommendation import constants as rconst
 from official.recommendation import stat_utils
@@ -53,6 +54,10 @@ def neumf_model_fn(features, labels, mode, params):
   logits = construct_model(users=users, items=items, params=params)

+  # Softmax with the first column of zeros is equivalent to sigmoid.
+  softmax_logits = tf.concat([tf.zeros(logits.shape, dtype=logits.dtype),
+                              logits], axis=1)
+
   if mode == tf.estimator.ModeKeys.PREDICT:
     predictions = {
         movielens.ITEM_COLUMN: items,
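The zeros-column trick above rests on the identity softmax([0, x])[1] = e^x / (1 + e^x) = sigmoid(x); computing the loss from the concatenated logits avoids the precision loss the commit message mentions for converting logits through a sigmoid first. A quick NumPy check of the identity (illustrative, not part of the commit):

import numpy as np

x = np.linspace(-5., 5., 11)
softmax_pos = np.exp(x) / (np.exp(0.) + np.exp(x))  # second class of softmax([0, x])
sigmoid = 1. / (1. + np.exp(-x))
assert np.allclose(softmax_pos, sigmoid)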
@@ -63,6 +68,12 @@ def neumf_model_fn(features, labels, mode, params):
       return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, predictions=predictions)
     return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

+  elif mode == tf.estimator.ModeKeys.EVAL:
+    duplicate_mask = tf.cast(features[rconst.DUPLICATE_MASK], tf.float32)
+    return compute_eval_loss_and_metrics(
+        logits, softmax_logits, duplicate_mask, params["num_neg"],
+        params["match_mlperf"], params["use_tpu"])
+
   elif mode == tf.estimator.ModeKeys.TRAIN:
     labels = tf.cast(labels, tf.int32)
     optimizer = tf.train.AdamOptimizer(
@@ -71,15 +82,14 @@ def neumf_model_fn(features, labels, mode, params):
     if params["use_tpu"]:
       optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

-    # Softmax with the first column of ones is equivalent to sigmoid.
-    logits = tf.concat([tf.ones(logits.shape, dtype=logits.dtype), logits],
-                       axis=1)
-
     loss = tf.losses.sparse_softmax_cross_entropy(
         labels=labels,
-        logits=logits
+        logits=softmax_logits
     )

+    # This tensor is used by logging hooks.
+    tf.identity(loss, name="cross_entropy")
+
     global_step = tf.train.get_global_step()
     tvars = tf.trainable_variables()
     gradients = optimizer.compute_gradients(
@@ -191,3 +201,151 @@ def construct_model(users, items, params):
   sys.stdout.flush()
   return logits
+def compute_eval_loss_and_metrics(logits,              # type: tf.Tensor
+                                  softmax_logits,      # type: tf.Tensor
+                                  duplicate_mask,      # type: tf.Tensor
+                                  num_training_neg,    # type: int
+                                  match_mlperf=False,  # type: bool
+                                  use_tpu=False        # type: bool
+                                 ):
+  # type: (...) -> tf.estimator.EstimatorSpec
+  """Model evaluation with HR and NDCG metrics.
+
+  The evaluation protocol is to rank the test interacted item (truth items)
+  among the randomly chosen 999 items that are not interacted by the user.
+  The performance of the ranked list is judged by Hit Ratio (HR) and
+  Normalized Discounted Cumulative Gain (NDCG).
+
+  For evaluation, the ranked list is truncated at 10 for both metrics. As
+  such, the HR intuitively measures whether the test item is present on the
+  top-10 list, and the NDCG accounts for the position of the hit by assigning
+  higher scores to hits at top ranks. Both metrics are calculated for each
+  test user, and the average scores are reported.
+
+  If `match_mlperf` is True, then the HR and NDCG computations are done in a
+  slightly unusual way to match the MLPerf reference implementation.
+  Specifically, if the evaluation negatives contain duplicate items, each
+  duplicate is treated as if the item only appeared once. Effectively, for
+  duplicate items in a row, the predicted score for all but one of the items
+  will be set to -infinity.
+
+  For example, suppose we have the following inputs:
+  logits_by_user: [[ 2, 3, 3],
+                   [ 5, 4, 4]]
+  items_by_user:  [[10, 20, 20],
+                   [30, 40, 40]]
+  # Note: items_by_user is not explicitly present. Instead the relevant
+  # information is contained within `duplicate_mask`.
+  top_k: 2
+
+  Then with match_mlperf=True, the HR would be 2/2 = 1.0. With
+  match_mlperf=False, the HR would be 1/2 = 0.5. This is because each user has
+  predicted scores for only 2 unique items: 10 and 20 for the first user, and
+  30 and 40 for the second. Therefore, with match_mlperf=True, it's guaranteed
+  the first item's score is in the top 2. With match_mlperf=False, this
+  function would determine that the first user's first item is not in the top
+  2, because item 20 has a higher score, and item 20 occurs twice.
+
+  Args:
+    logits: A tensor containing the predicted logits for each user. The shape
+      of logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),). Logits
+      for a user are grouped, and the first element of the group is the true
+      element.
+    softmax_logits: The same tensor, but with a column of zeros concatenated
+      on the left, so that a two-class softmax reproduces the sigmoid of
+      `logits`.
+    duplicate_mask: A vector with the same shape as logits, with a value of 1
+      if the item corresponding to the logit at that position has already
+      appeared for that user.
+    num_training_neg: The number of negatives per positive during training.
+    match_mlperf: Use the MLPerf reference convention for computing rank.
+    use_tpu: Should the evaluation be performed on a TPU.
+
+  Returns:
+    An EstimatorSpec for evaluation.
+  """
+  logits_by_user = tf.reshape(logits, (-1, rconst.NUM_EVAL_NEGATIVES + 1))
+  duplicate_mask_by_user = tf.reshape(duplicate_mask,
+                                      (-1, rconst.NUM_EVAL_NEGATIVES + 1))
+
+  if match_mlperf:
+    # Set duplicate logits to the min value for that dtype. The MLPerf
+    # reference dedupes during evaluation.
+    logits_by_user *= (1 - duplicate_mask_by_user)
+    logits_by_user += duplicate_mask_by_user * logits_by_user.dtype.min
+
+  # Determine the location of the first element in each row after the
+  # elements are sorted.
+  sort_indices = tf.contrib.framework.argsort(
+      logits_by_user, axis=1, direction="DESCENDING")
+
+  # Use matrix multiplication to extract the position of the true item from
+  # the tensor of sorted indices. This approach is chosen because both GPUs
+  # and TPUs perform matrix multiplications very quickly. This is similar to
+  # np.argwhere. However this is a special case because the target will only
+  # appear in sort_indices once.
+  one_hot_position = tf.cast(tf.equal(sort_indices, 0), tf.int32)
+  sparse_positions = tf.multiply(
+      one_hot_position, tf.range(logits_by_user.shape[1])[tf.newaxis, :])
+  position_vector = tf.reduce_sum(sparse_positions, axis=1)
+
+  in_top_k = tf.cast(tf.less(position_vector, rconst.TOP_K), tf.float32)
+  ndcg = tf.log(2.) / tf.log(tf.cast(position_vector, tf.float32) + 2)
+  ndcg *= in_top_k
+
+  # If a row is a padded row, all but the first element will be a duplicate.
+  metric_weights = tf.not_equal(tf.reduce_sum(duplicate_mask_by_user, axis=1),
+                                rconst.NUM_EVAL_NEGATIVES)
+
+  # Examples are provided by the eval Dataset in a structured format, so eval
+  # labels can be reconstructed on the fly.
+  eval_labels = tf.reshape(tf.one_hot(
+      tf.zeros(shape=(logits_by_user.shape[0],), dtype=tf.int32),
+      logits_by_user.shape[1], dtype=tf.int32), (-1,))
+  eval_labels_float = tf.cast(eval_labels, tf.float32)
+
+  # During evaluation, the ratio of negatives to positives is much higher
+  # than during training. (Typically 999 to 1 vs. 4 to 1.) By adjusting the
+  # weights for the negative examples we compute a loss which is consistent
+  # with the training data. (And provides an apples-to-apples comparison.)
+  negative_scale_factor = num_training_neg / rconst.NUM_EVAL_NEGATIVES
+  example_weights = (
+      (eval_labels_float + (1 - eval_labels_float) * negative_scale_factor) *
+      (1 + rconst.NUM_EVAL_NEGATIVES) / (1 + num_training_neg))
+
+  # Tile metric weights back to logit dimensions.
+  expanded_metric_weights = tf.reshape(tf.tile(
+      metric_weights[:, tf.newaxis], (1, rconst.NUM_EVAL_NEGATIVES + 1)), (-1,))
+
+  # Ignore padded examples.
+  example_weights *= tf.cast(expanded_metric_weights, tf.float32)
+
+  cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+      logits=softmax_logits, labels=eval_labels, weights=example_weights)
+
+  def metric_fn(top_k_tensor, ndcg_tensor, weight_tensor):
+    return {
+        rconst.HR_KEY: tf.metrics.mean(top_k_tensor, weights=weight_tensor),
+        rconst.NDCG_KEY: tf.metrics.mean(ndcg_tensor, weights=weight_tensor),
+    }
+
+  if use_tpu:
+    return tf.contrib.tpu.TPUEstimatorSpec(
+        mode=tf.estimator.ModeKeys.EVAL, loss=cross_entropy,
+        eval_metrics=(metric_fn, [in_top_k, ndcg, metric_weights]))
+
+  return tf.estimator.EstimatorSpec(
+      mode=tf.estimator.ModeKeys.EVAL,
+      loss=cross_entropy,
+      eval_metric_ops=metric_fn(in_top_k, ndcg, metric_weights)
+  )
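A small NumPy sketch (illustrative, not part of the commit) of the rank extraction above: after the descending argsort, the position of index 0 in each row is the rank of the true item, recovered with the same one-hot-and-reduce trick.

import numpy as np

logits_by_user = np.array([[2., 3., 3.],
                           [5., 4., 4.]])
sort_indices = np.argsort(-logits_by_user, axis=1)  # descending sort
one_hot_position = (sort_indices == 0).astype(np.int32)
position_vector = (one_hot_position *
                   np.arange(logits_by_user.shape[1])).sum(axis=1)
print(position_vector)      # [2 0]
print(position_vector < 2)  # [False  True]: HR@2 hits only for the second user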
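The loss reweighting can also be checked by hand. A worked example (illustrative values, not part of the commit) with num_training_neg = 4 and NUM_EVAL_NEGATIVES = 999:

negative_scale_factor = 4 / 999
scale = (1 + 999) / (1 + 4)                      # = 200.0
positive_weight = 1.0 * scale                    # = 200.0
negative_weight = negative_scale_factor * scale  # ~= 0.8008
total = positive_weight + 999 * negative_weight  # ~= 1000.0 per user
# Positives carry positive_weight / total = 1/5 of the weight, matching the
# 1 positive : 4 negatives balance of the training loss.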
@@ -56,7 +56,7 @@ do
       --clean \
       --train_epochs 20 \
       --batch_size 2048 \
-      --eval_batch_size 65536 \
+      --eval_batch_size 100000 \
       --learning_rate 0.0005 \
       --layers 256,256,128,64 --num_factors 64 \
       --hr_threshold 0.635 \
@@ -67,6 +67,12 @@ do
     END_TIME=$(date +%s)
     echo "Run ${i} complete: $(( $END_TIME - $START_TIME )) seconds."

+    # Don't fill up the local hard drive.
+    if [[ -z ${BUCKET} ]]; then
+      echo "Removing model directory to save space."
+      rm -r ${MODEL_DIR}
+    fi
+
   done
 } |& tee "${LOCAL_TEST_DIR}/summary.log"
@@ -83,3 +83,36 @@ def sample_with_exclusion(num_items, positive_set, n, replacement=True):
   # in practice tends to be quite ordered.
   return negatives[:n]
+
+
+def mask_duplicates(x, axis=1):  # type: (np.ndarray, int) -> np.ndarray
+  """Identify duplicates from sampling with replacement.
+
+  Args:
+    x: A 2D NumPy array of samples.
+    axis: The axis along which to de-dupe.
+
+  Returns:
+    A NumPy array with the same shape as x, with one if an element appeared
+    previously along axis 1, else zero.
+  """
+  if axis != 1:
+    raise NotImplementedError
+
+  x_sort_ind = np.argsort(x, axis=1, kind="mergesort")
+  sorted_x = x[np.arange(x.shape[0])[:, np.newaxis], x_sort_ind]
+
+  # Compute the indices needed to map values back to their original position.
+  inv_x_sort_ind = np.argsort(x_sort_ind, axis=1, kind="mergesort")
+
+  # Compute the difference of adjacent sorted elements.
+  diffs = sorted_x[:, :-1] - sorted_x[:, 1:]
+
+  # We are only interested in whether an element is zero. Therefore left
+  # padding with ones to restore the original shape is sufficient.
+  diffs = np.concatenate(
+      [np.ones((diffs.shape[0], 1), dtype=diffs.dtype), diffs], axis=1)
+
+  # Duplicate values will have a difference of zero. By definition the first
+  # element is never a duplicate.
+  return np.where(diffs[np.arange(x.shape[0])[:, np.newaxis],
+                        inv_x_sort_ind], 0, 1)
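A hypothetical usage sketch of mask_duplicates (the input values are made up): the first occurrence of each value in a row stays 0 and every repeat is flagged 1, which is exactly what the eval pipeline serializes as the duplicate mask.

import numpy as np

x = np.array([[10, 20, 20, 30],
              [40, 50, 60, 50]])
print(mask_duplicates(x, axis=1))
# [[0 0 1 0]
#  [0 0 0 1]]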