"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "2ab84ed5d721fc0a36a2f77f2ffc10e55d7b15dd"
Commit 828b99cd authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Adds type annotations.

PiperOrigin-RevId: 309357174
parent a135c0f0
...@@ -549,7 +549,7 @@ def build_stats(loss, eval_result, time_callback): ...@@ -549,7 +549,7 @@ def build_stats(loss, eval_result, time_callback):
def main(_): def main(_):
run_ncf(FLAGS) logging.info("Result is %s", run_ncf(FLAGS))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -31,12 +31,15 @@ the two models by concatenating their last hidden layer. ...@@ -31,12 +31,15 @@ the two models by concatenating their last hidden layer.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function from __future__ import print_function
import sys import sys
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
from typing import Any, Dict, Text
from official.recommendation import constants as rconst from official.recommendation import constants as rconst
from official.recommendation import movielens from official.recommendation import movielens
from official.recommendation import ncf_common from official.recommendation import ncf_common
...@@ -133,14 +136,15 @@ def _strip_first_and_last_dimension(x, batch_size): ...@@ -133,14 +136,15 @@ def _strip_first_and_last_dimension(x, batch_size):
return tf.reshape(x[0, :], (batch_size,)) return tf.reshape(x[0, :], (batch_size,))
def construct_model(user_input, item_input, params): def construct_model(user_input: tf.Tensor, item_input: tf.Tensor,
# type: (tf.Tensor, tf.Tensor, dict) -> tf.keras.Model params: Dict[Text, Any]) -> tf.keras.Model:
"""Initialize NeuMF model. """Initialize NeuMF model.
Args: Args:
user_input: keras input layer for users user_input: keras input layer for users
item_input: keras input layer for items item_input: keras input layer for items
params: Dict of hyperparameters. params: Dict of hyperparameters.
Raises: Raises:
ValueError: if the first model layer is not even. ValueError: if the first model layer is not even.
Returns: Returns:
...@@ -232,13 +236,12 @@ def construct_model(user_input, item_input, params): ...@@ -232,13 +236,12 @@ def construct_model(user_input, item_input, params):
return model return model
def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor def _get_estimator_spec_with_metrics(logits: tf.Tensor,
softmax_logits, # type: tf.Tensor softmax_logits: tf.Tensor,
duplicate_mask, # type: tf.Tensor duplicate_mask: tf.Tensor,
num_training_neg, # type: int num_training_neg: int,
match_mlperf=False, # type: bool match_mlperf: bool = False,
use_tpu_spec=False # type: bool use_tpu_spec: bool = False):
):
"""Returns a EstimatorSpec that includes the metrics.""" """Returns a EstimatorSpec that includes the metrics."""
cross_entropy, \ cross_entropy, \
metric_fn, \ metric_fn, \
...@@ -264,13 +267,11 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor ...@@ -264,13 +267,11 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor
) )
def compute_eval_loss_and_metrics_helper( def compute_eval_loss_and_metrics_helper(logits: tf.Tensor,
logits, # type: tf.Tensor softmax_logits: tf.Tensor,
softmax_logits, # type: tf.Tensor duplicate_mask: tf.Tensor,
duplicate_mask, # type: tf.Tensor num_training_neg: int,
num_training_neg, # type: int match_mlperf: bool = False):
match_mlperf=False # type: bool
):
"""Model evaluation with HR and NDCG metrics. """Model evaluation with HR and NDCG metrics.
The evaluation protocol is to rank the test interacted item (truth items) The evaluation protocol is to rank the test interacted item (truth items)
...@@ -312,19 +313,14 @@ def compute_eval_loss_and_metrics_helper( ...@@ -312,19 +313,14 @@ def compute_eval_loss_and_metrics_helper(
has a higher score, and item 20 occurs twice. has a higher score, and item 20 occurs twice.
Args: Args:
logits: A tensor containing the predicted logits for each user. The shape logits: A tensor containing the predicted logits for each user. The shape of
of logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),) Logits logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),) Logits for a
for a user are grouped, and the last element of the group is the true user are grouped, and the last element of the group is the true element.
element.
softmax_logits: The same tensor, but with zeros left-appended. softmax_logits: The same tensor, but with zeros left-appended.
duplicate_mask: A vector with the same shape as logits, with a value of 1 if
duplicate_mask: A vector with the same shape as logits, with a value of 1 the item corresponding to the logit at that position has already appeared
if the item corresponding to the logit at that position has already for that user.
appeared for that user.
num_training_neg: The number of negatives per positive during training. num_training_neg: The number of negatives per positive during training.
match_mlperf: Use the MLPerf reference convention for computing rank. match_mlperf: Use the MLPerf reference convention for computing rank.
Returns: Returns:
...@@ -377,20 +373,18 @@ def compute_eval_loss_and_metrics_helper( ...@@ -377,20 +373,18 @@ def compute_eval_loss_and_metrics_helper(
return cross_entropy, metric_fn, in_top_k, ndcg, metric_weights return cross_entropy, metric_fn, in_top_k, ndcg, metric_weights
def compute_top_k_and_ndcg(logits, # type: tf.Tensor def compute_top_k_and_ndcg(logits: tf.Tensor,
duplicate_mask, # type: tf.Tensor duplicate_mask: tf.Tensor,
match_mlperf=False # type: bool match_mlperf: bool = False):
):
"""Compute inputs of metric calculation. """Compute inputs of metric calculation.
Args: Args:
logits: A tensor containing the predicted logits for each user. The shape logits: A tensor containing the predicted logits for each user. The shape of
of logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),) Logits logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),) Logits for a
for a user are grouped, and the first element of the group is the true user are grouped, and the first element of the group is the true element.
element. duplicate_mask: A vector with the same shape as logits, with a value of 1 if
duplicate_mask: A vector with the same shape as logits, with a value of 1 the item corresponding to the logit at that position has already appeared
if the item corresponding to the logit at that position has already for that user.
appeared for that user.
match_mlperf: Use the MLPerf reference convention for computing rank. match_mlperf: Use the MLPerf reference convention for computing rank.
Returns: Returns:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment