Commit 828b99cd authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Adds type annotations.

PiperOrigin-RevId: 309357174
parent a135c0f0
......@@ -549,7 +549,7 @@ def build_stats(loss, eval_result, time_callback):
def main(_):
run_ncf(FLAGS)
logging.info("Result is %s", run_ncf(FLAGS))
if __name__ == "__main__":
......
......@@ -31,12 +31,15 @@ the two models by concatenating their last hidden layer.
"""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import sys
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from typing import Any, Dict, Text
from official.recommendation import constants as rconst
from official.recommendation import movielens
from official.recommendation import ncf_common
......@@ -133,14 +136,15 @@ def _strip_first_and_last_dimension(x, batch_size):
return tf.reshape(x[0, :], (batch_size,))
def construct_model(user_input, item_input, params):
# type: (tf.Tensor, tf.Tensor, dict) -> tf.keras.Model
def construct_model(user_input: tf.Tensor, item_input: tf.Tensor,
params: Dict[Text, Any]) -> tf.keras.Model:
"""Initialize NeuMF model.
Args:
user_input: keras input layer for users
item_input: keras input layer for items
params: Dict of hyperparameters.
Raises:
ValueError: if the first model layer is not even.
Returns:
......@@ -232,13 +236,12 @@ def construct_model(user_input, item_input, params):
return model
def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor
softmax_logits, # type: tf.Tensor
duplicate_mask, # type: tf.Tensor
num_training_neg, # type: int
match_mlperf=False, # type: bool
use_tpu_spec=False # type: bool
):
def _get_estimator_spec_with_metrics(logits: tf.Tensor,
softmax_logits: tf.Tensor,
duplicate_mask: tf.Tensor,
num_training_neg: int,
match_mlperf: bool = False,
use_tpu_spec: bool = False):
"""Returns a EstimatorSpec that includes the metrics."""
cross_entropy, \
metric_fn, \
......@@ -264,13 +267,11 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor
)
def compute_eval_loss_and_metrics_helper(
logits, # type: tf.Tensor
softmax_logits, # type: tf.Tensor
duplicate_mask, # type: tf.Tensor
num_training_neg, # type: int
match_mlperf=False # type: bool
):
def compute_eval_loss_and_metrics_helper(logits: tf.Tensor,
softmax_logits: tf.Tensor,
duplicate_mask: tf.Tensor,
num_training_neg: int,
match_mlperf: bool = False):
"""Model evaluation with HR and NDCG metrics.
The evaluation protocol is to rank the test interacted item (truth items)
......@@ -312,19 +313,14 @@ def compute_eval_loss_and_metrics_helper(
has a higher score, and item 20 occurs twice.
Args:
logits: A tensor containing the predicted logits for each user. The shape
of logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),) Logits
for a user are grouped, and the last element of the group is the true
element.
logits: A tensor containing the predicted logits for each user. The shape of
logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),) Logits for a
user are grouped, and the last element of the group is the true element.
softmax_logits: The same tensor, but with zeros left-appended.
duplicate_mask: A vector with the same shape as logits, with a value of 1
if the item corresponding to the logit at that position has already
appeared for that user.
duplicate_mask: A vector with the same shape as logits, with a value of 1 if
the item corresponding to the logit at that position has already appeared
for that user.
num_training_neg: The number of negatives per positive during training.
match_mlperf: Use the MLPerf reference convention for computing rank.
Returns:
......@@ -377,20 +373,18 @@ def compute_eval_loss_and_metrics_helper(
return cross_entropy, metric_fn, in_top_k, ndcg, metric_weights
def compute_top_k_and_ndcg(logits, # type: tf.Tensor
duplicate_mask, # type: tf.Tensor
match_mlperf=False # type: bool
):
def compute_top_k_and_ndcg(logits: tf.Tensor,
duplicate_mask: tf.Tensor,
match_mlperf: bool = False):
"""Compute inputs of metric calculation.
Args:
logits: A tensor containing the predicted logits for each user. The shape
of logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),) Logits
for a user are grouped, and the first element of the group is the true
element.
duplicate_mask: A vector with the same shape as logits, with a value of 1
if the item corresponding to the logit at that position has already
appeared for that user.
logits: A tensor containing the predicted logits for each user. The shape of
logits is (num_users_per_batch * (1 + NUM_EVAL_NEGATIVES),) Logits for a
user are grouped, and the first element of the group is the true element.
duplicate_mask: A vector with the same shape as logits, with a value of 1 if
the item corresponding to the logit at that position has already appeared
for that user.
match_mlperf: Use the MLPerf reference convention for computing rank.
Returns:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment