Commit f0a8be5d authored by guptapriya's avatar guptapriya Committed by guptapriya
Browse files

try #1 to fix CTL

parent 70704b94
...@@ -68,7 +68,7 @@ class MetricLayer(tf.keras.layers.Layer): ...@@ -68,7 +68,7 @@ class MetricLayer(tf.keras.layers.Layer):
return inputs[0] return inputs[0]
def _get_train_and_eval_data(producer, params): def _get_train_and_eval_data(producer, params):
"""Returns the datasets for training and evalutating.""" """Returns the datasets for training and evalutating."""
def preprocess_train_input(features, labels): def preprocess_train_input(features, labels):
...@@ -313,8 +313,7 @@ def run_ncf(_): ...@@ -313,8 +313,7 @@ def run_ncf(_):
"""Computes loss and applied gradient per replica.""" """Computes loss and applied gradient per replica."""
features, labels = inputs features, labels = inputs
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
softmax_logits = keras_model([features[movielens.USER_COLUMN], softmax_logits = keras_model(features)
features[movielens.ITEM_COLUMN]])
loss = loss_object(labels, softmax_logits, loss = loss_object(labels, softmax_logits,
sample_weight=features[rconst.VALID_POINT_MASK]) sample_weight=features[rconst.VALID_POINT_MASK])
loss *= (1.0 / (batch_size*strategy.num_replicas_in_sync)) loss *= (1.0 / (batch_size*strategy.num_replicas_in_sync))
...@@ -336,8 +335,7 @@ def run_ncf(_): ...@@ -336,8 +335,7 @@ def run_ncf(_):
def step_fn(inputs): def step_fn(inputs):
"""Computes eval metrics per replica.""" """Computes eval metrics per replica."""
features, _ = inputs features, _ = inputs
softmax_logits = keras_model([features[movielens.USER_COLUMN], softmax_logits = keras_model(features)
features[movielens.ITEM_COLUMN]])
logits = tf.slice(softmax_logits, [0, 0, 1], [-1, -1, -1]) logits = tf.slice(softmax_logits, [0, 0, 1], [-1, -1, -1])
dup_mask = features[rconst.DUPLICATE_MASK] dup_mask = features[rconst.DUPLICATE_MASK]
in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg( in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg(
...@@ -412,7 +410,7 @@ def run_ncf(_): ...@@ -412,7 +410,7 @@ def run_ncf(_):
train_history = history.history train_history = history.history
train_loss = train_history["loss"][-1] train_loss = train_history["loss"][-1]
stats = build_stats(train_loss, eval_results, time_callback) stats = build_stats(train_loss, eval_results, None) #, time_callback)
return stats return stats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment