Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
91000bc5
Commit
91000bc5
authored
Oct 17, 2018
by
Shawn Wang
Browse files
Refactor neumf_model.py to support users who just need top_k and ndcg tensors.
parent
69b01644
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
36 deletions
+63
-36
official/recommendation/neumf_model.py
official/recommendation/neumf_model.py
+63
-36
No files found.
official/recommendation/neumf_model.py
View file @
91000bc5
...
@@ -36,12 +36,13 @@ from __future__ import print_function
...
@@ -36,12 +36,13 @@ from __future__ import print_function
import
sys
import
sys
import
typing
import
typing
import
google3
from
six.moves
import
xrange
# pylint: disable=redefined-builtin
from
six.moves
import
xrange
# pylint: disable=redefined-builtin
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.datasets
import
movielens
# pylint: disable=g-bad-import-order
from
google3.third_party.tensorflow_models.
official.datasets
import
movielens
# pylint: disable=g-bad-import-order
from
official.recommendation
import
constants
as
rconst
from
google3.third_party.tensorflow_models.
official.recommendation
import
constants
as
rconst
from
official.recommendation
import
stat_utils
from
google3.third_party.tensorflow_models.
official.recommendation
import
stat_utils
def
_sparse_to_dense_grads
(
grads_and_vars
):
def
_sparse_to_dense_grads
(
grads_and_vars
):
...
@@ -298,39 +299,8 @@ def compute_eval_loss_and_metrics(logits, # type: tf.Tensor
...
@@ -298,39 +299,8 @@ def compute_eval_loss_and_metrics(logits, # type: tf.Tensor
Returns:
Returns:
An EstimatorSpec for evaluation.
An EstimatorSpec for evaluation.
"""
"""
in_top_k
,
ndcg
,
metric_weights
,
logits_by_user
=
compute_top_k_and_ndcg
(
logits_by_user
=
tf
.
reshape
(
logits
,
(
-
1
,
rconst
.
NUM_EVAL_NEGATIVES
+
1
))
logits
,
duplicate_mask
,
match_mlperf
)
duplicate_mask_by_user
=
tf
.
reshape
(
duplicate_mask
,
(
-
1
,
rconst
.
NUM_EVAL_NEGATIVES
+
1
))
if
match_mlperf
:
# Set duplicate logits to the min value for that dtype. The MLPerf
# reference dedupes during evaluation.
logits_by_user
*=
(
1
-
duplicate_mask_by_user
)
logits_by_user
+=
duplicate_mask_by_user
*
logits_by_user
.
dtype
.
min
# Determine the location of the first element in each row after the elements
# are sorted.
sort_indices
=
tf
.
contrib
.
framework
.
argsort
(
logits_by_user
,
axis
=
1
,
direction
=
"DESCENDING"
)
# Use matrix multiplication to extract the position of the true item from the
# tensor of sorted indices. This approach is chosen because both GPUs and TPUs
# perform matrix multiplications very quickly. This is similar to np.argwhere.
# However this is a special case because the target will only appear in
# sort_indices once.
one_hot_position
=
tf
.
cast
(
tf
.
equal
(
sort_indices
,
0
),
tf
.
int32
)
sparse_positions
=
tf
.
multiply
(
one_hot_position
,
tf
.
range
(
logits_by_user
.
shape
[
1
])[
tf
.
newaxis
,
:])
position_vector
=
tf
.
reduce_sum
(
sparse_positions
,
axis
=
1
)
in_top_k
=
tf
.
cast
(
tf
.
less
(
position_vector
,
rconst
.
TOP_K
),
tf
.
float32
)
ndcg
=
tf
.
log
(
2.
)
/
tf
.
log
(
tf
.
cast
(
position_vector
,
tf
.
float32
)
+
2
)
ndcg
*=
in_top_k
# If a row is a padded row, all but the first element will be a duplicate.
metric_weights
=
tf
.
not_equal
(
tf
.
reduce_sum
(
duplicate_mask_by_user
,
axis
=
1
),
rconst
.
NUM_EVAL_NEGATIVES
)
# Examples are provided by the eval Dataset in a structured format, so eval
# Examples are provided by the eval Dataset in a structured format, so eval
# labels can be reconstructed on the fly.
# labels can be reconstructed on the fly.
...
@@ -375,3 +345,60 @@ def compute_eval_loss_and_metrics(logits, # type: tf.Tensor
...
@@ -375,3 +345,60 @@ def compute_eval_loss_and_metrics(logits, # type: tf.Tensor
loss
=
cross_entropy
,
loss
=
cross_entropy
,
eval_metric_ops
=
metric_fn
(
in_top_k
,
ndcg
,
metric_weights
)
eval_metric_ops
=
metric_fn
(
in_top_k
,
ndcg
,
metric_weights
)
)
)
def compute_top_k_and_ndcg(logits,              # type: tf.Tensor
                           duplicate_mask,      # type: tf.Tensor
                           match_mlperf=False   # type: bool
                          ):
  """Compute the per-user inputs needed for HR@K and NDCG metrics.

  Args:
    logits: Predicted logits for every user, flattened to shape
      (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),). Each user's logits
      are contiguous, and the first logit of each group belongs to the true
      (positive) item.
    duplicate_mask: A vector with the same shape as logits; an entry is 1 if
      the corresponding item already appeared earlier for that user, else 0.
    match_mlperf: If True, follow the MLPerf reference convention of deduping
      items before ranking.

  Returns:
    A tuple (in_top_k, ndcg, metric_weights, logits_by_user) where the first
    three have shape (num_users_in_batch,) and logits_by_user has shape
    (num_users_in_batch, rconst.NUM_EVAL_NEGATIVES + 1).
  """
  group_size = rconst.NUM_EVAL_NEGATIVES + 1
  logits_by_user = tf.reshape(logits, (-1, group_size))
  duplicate_mask_by_user = tf.reshape(duplicate_mask, (-1, group_size))

  if match_mlperf:
    # The MLPerf reference dedupes during evaluation: force duplicate logits
    # down to the dtype's minimum so they sort to the bottom of each row.
    logits_by_user *= (1 - duplicate_mask_by_user)
    logits_by_user += duplicate_mask_by_user * logits_by_user.dtype.min

  # Sort each row in descending order of score.
  sort_indices = tf.contrib.framework.argsort(
      logits_by_user, axis=1, direction="DESCENDING")

  # The true item sits at column 0 before sorting, so its rank is wherever
  # index 0 ended up in sort_indices. Recover that rank with a one-hot mask
  # multiplied by a column-index row vector and a row-wise sum — effectively
  # np.argwhere for a value guaranteed to appear exactly once per row, done
  # with ops (multiply + reduce) that GPUs and TPUs execute very quickly.
  one_hot_position = tf.cast(tf.equal(sort_indices, 0), tf.int32)
  column_indices = tf.range(logits_by_user.shape[1])[tf.newaxis, :]
  position_vector = tf.reduce_sum(
      tf.multiply(one_hot_position, column_indices), axis=1)

  # Hit indicator: 1.0 when the true item ranks inside the top K.
  in_top_k = tf.cast(tf.less(position_vector, rconst.TOP_K), tf.float32)

  # NDCG with a single relevant item reduces to log(2) / log(rank + 2);
  # zero it out for users whose true item missed the top K.
  ndcg = tf.log(2.) / tf.log(tf.cast(position_vector, tf.float32) + 2)
  ndcg *= in_top_k

  # A padded row duplicates every element after the first, so its duplicate
  # count equals NUM_EVAL_NEGATIVES; such rows get a False (zero) weight.
  metric_weights = tf.not_equal(
      tf.reduce_sum(duplicate_mask_by_user, axis=1),
      rconst.NUM_EVAL_NEGATIVES)

  return in_top_k, ndcg, metric_weights, logits_by_user
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment