Unverified Commit 4bd0d89d authored by shizhiw's avatar shizhiw Committed by GitHub
Browse files

Merge pull request #5557 from tensorflow/shizhiw_20181017

 Refactor neumf_model.py to support users who just need top_k and ndcg tensors.
parents 69b01644 3ec25e5d
...@@ -298,39 +298,8 @@ def compute_eval_loss_and_metrics(logits, # type: tf.Tensor ...@@ -298,39 +298,8 @@ def compute_eval_loss_and_metrics(logits, # type: tf.Tensor
Returns: Returns:
An EstimatorSpec for evaluation. An EstimatorSpec for evaluation.
""" """
in_top_k, ndcg, metric_weights, logits_by_user = compute_top_k_and_ndcg(
logits_by_user = tf.reshape(logits, (-1, rconst.NUM_EVAL_NEGATIVES + 1)) logits, duplicate_mask, match_mlperf)
duplicate_mask_by_user = tf.reshape(duplicate_mask,
(-1, rconst.NUM_EVAL_NEGATIVES + 1))
if match_mlperf:
# Set duplicate logits to the min value for that dtype. The MLPerf
# reference dedupes during evaluation.
logits_by_user *= (1 - duplicate_mask_by_user)
logits_by_user += duplicate_mask_by_user * logits_by_user.dtype.min
# Determine the location of the first element in each row after the elements
# are sorted.
sort_indices = tf.contrib.framework.argsort(
logits_by_user, axis=1, direction="DESCENDING")
# Use matrix multiplication to extract the position of the true item from the
# tensor of sorted indices. This approach is chosen because both GPUs and TPUs
# perform matrix multiplications very quickly. This is similar to np.argwhere.
# However this is a special case because the target will only appear in
# sort_indices once.
one_hot_position = tf.cast(tf.equal(sort_indices, 0), tf.int32)
sparse_positions = tf.multiply(
one_hot_position, tf.range(logits_by_user.shape[1])[tf.newaxis, :])
position_vector = tf.reduce_sum(sparse_positions, axis=1)
in_top_k = tf.cast(tf.less(position_vector, rconst.TOP_K), tf.float32)
ndcg = tf.log(2.) / tf.log(tf.cast(position_vector, tf.float32) + 2)
ndcg *= in_top_k
# If a row is a padded row, all but the first element will be a duplicate.
metric_weights = tf.not_equal(tf.reduce_sum(duplicate_mask_by_user, axis=1),
rconst.NUM_EVAL_NEGATIVES)
# Examples are provided by the eval Dataset in a structured format, so eval # Examples are provided by the eval Dataset in a structured format, so eval
# labels can be reconstructed on the fly. # labels can be reconstructed on the fly.
...@@ -375,3 +344,60 @@ def compute_eval_loss_and_metrics(logits, # type: tf.Tensor ...@@ -375,3 +344,60 @@ def compute_eval_loss_and_metrics(logits, # type: tf.Tensor
loss=cross_entropy, loss=cross_entropy,
eval_metric_ops=metric_fn(in_top_k, ndcg, metric_weights) eval_metric_ops=metric_fn(in_top_k, ndcg, metric_weights)
) )
def compute_top_k_and_ndcg(logits, # type: tf.Tensor
duplicate_mask, # type: tf.Tensor
match_mlperf=False # type: bool
):
"""Compute inputs of metric calculation.
Args:
logits: A tensor containing the predicted logits for each user. The shape
of logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),) Logits
for a user are grouped, and the first element of the group is the true
element.
duplicate_mask: A vector with the same shape as logits, with a value of 1
if the item corresponding to the logit at that position has already
appeared for that user.
match_mlperf: Use the MLPerf reference convention for computing rank.
Returns:
is_top_k, ndcg and weights, all of which has size (num_users_in_batch,), and
logits_by_user which has size
(num_users_in_batch, (rconst.NUM_EVAL_NEGATIVES + 1)).
"""
logits_by_user = tf.reshape(logits, (-1, rconst.NUM_EVAL_NEGATIVES + 1))
duplicate_mask_by_user = tf.reshape(duplicate_mask,
(-1, rconst.NUM_EVAL_NEGATIVES + 1))
if match_mlperf:
# Set duplicate logits to the min value for that dtype. The MLPerf
# reference dedupes during evaluation.
logits_by_user *= (1 - duplicate_mask_by_user)
logits_by_user += duplicate_mask_by_user * logits_by_user.dtype.min
# Determine the location of the first element in each row after the elements
# are sorted.
sort_indices = tf.contrib.framework.argsort(
logits_by_user, axis=1, direction="DESCENDING")
# Use matrix multiplication to extract the position of the true item from the
# tensor of sorted indices. This approach is chosen because both GPUs and TPUs
# perform matrix multiplications very quickly. This is similar to np.argwhere.
# However this is a special case because the target will only appear in
# sort_indices once.
one_hot_position = tf.cast(tf.equal(sort_indices, 0), tf.int32)
sparse_positions = tf.multiply(
one_hot_position, tf.range(logits_by_user.shape[1])[tf.newaxis, :])
position_vector = tf.reduce_sum(sparse_positions, axis=1)
in_top_k = tf.cast(tf.less(position_vector, rconst.TOP_K), tf.float32)
ndcg = tf.log(2.) / tf.log(tf.cast(position_vector, tf.float32) + 2)
ndcg *= in_top_k
# If a row is a padded row, all but the first element will be a duplicate.
metric_weights = tf.not_equal(tf.reduce_sum(duplicate_mask_by_user, axis=1),
rconst.NUM_EVAL_NEGATIVES)
return in_top_k, ndcg, metric_weights, logits_by_user
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment