Commit 9a0986d1 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Fix a bug that ncf_keras model cannot be serialized as JSON.

PiperOrigin-RevId: 314664026
parent 869a4806
...@@ -37,21 +37,22 @@ from official.recommendation import movielens ...@@ -37,21 +37,22 @@ from official.recommendation import movielens
from official.recommendation import ncf_common from official.recommendation import ncf_common
from official.recommendation import ncf_input_pipeline from official.recommendation import ncf_input_pipeline
from official.recommendation import neumf_model from official.recommendation import neumf_model
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
from official.utils.flags import core as flags_core
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def metric_fn(logits, dup_mask, params): def metric_fn(logits, dup_mask, match_mlperf):
dup_mask = tf.cast(dup_mask, tf.float32) dup_mask = tf.cast(dup_mask, tf.float32)
logits = tf.slice(logits, [0, 1], [-1, -1]) logits = tf.slice(logits, [0, 1], [-1, -1])
in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg( in_top_k, _, metric_weights, _ = neumf_model.compute_top_k_and_ndcg(
logits, logits,
dup_mask, dup_mask,
params["match_mlperf"]) match_mlperf)
metric_weights = tf.cast(metric_weights, tf.float32) metric_weights = tf.cast(metric_weights, tf.float32)
return in_top_k, metric_weights return in_top_k, metric_weights
...@@ -59,9 +60,16 @@ def metric_fn(logits, dup_mask, params): ...@@ -59,9 +60,16 @@ def metric_fn(logits, dup_mask, params):
class MetricLayer(tf.keras.layers.Layer): class MetricLayer(tf.keras.layers.Layer):
"""Custom layer of metrics for NCF model.""" """Custom layer of metrics for NCF model."""
def __init__(self, params): def __init__(self, match_mlperf):
super(MetricLayer, self).__init__() super(MetricLayer, self).__init__()
self.params = params self.match_mlperf = match_mlperf
def get_config(self):
return {"match_mlperf": self.match_mlperf}
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
def call(self, inputs, training=False): def call(self, inputs, training=False):
logits, dup_mask = inputs logits, dup_mask = inputs
...@@ -70,7 +78,7 @@ class MetricLayer(tf.keras.layers.Layer): ...@@ -70,7 +78,7 @@ class MetricLayer(tf.keras.layers.Layer):
hr_sum = 0.0 hr_sum = 0.0
hr_count = 0.0 hr_count = 0.0
else: else:
metric, metric_weights = metric_fn(logits, dup_mask, self.params) metric, metric_weights = metric_fn(logits, dup_mask, self.match_mlperf)
hr_sum = tf.reduce_sum(metric * metric_weights) hr_sum = tf.reduce_sum(metric * metric_weights)
hr_count = tf.reduce_sum(metric_weights) hr_count = tf.reduce_sum(metric_weights)
...@@ -89,6 +97,13 @@ class LossLayer(tf.keras.layers.Layer): ...@@ -89,6 +97,13 @@ class LossLayer(tf.keras.layers.Layer):
self.loss = tf.keras.losses.SparseCategoricalCrossentropy( self.loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction="sum") from_logits=True, reduction="sum")
def get_config(self):
return {"loss_normalization_factor": self.loss_normalization_factor}
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
def call(self, inputs): def call(self, inputs):
logits, labels, valid_pt_mask_input = inputs logits, labels, valid_pt_mask_input = inputs
loss = self.loss( loss = self.loss(
...@@ -409,7 +424,7 @@ def run_ncf_custom_training(params, ...@@ -409,7 +424,7 @@ def run_ncf_custom_training(params,
softmax_logits = keras_model(features) softmax_logits = keras_model(features)
in_top_k, metric_weights = metric_fn(softmax_logits, in_top_k, metric_weights = metric_fn(softmax_logits,
features[rconst.DUPLICATE_MASK], features[rconst.DUPLICATE_MASK],
params) params["match_mlperf"])
hr_sum = tf.reduce_sum(in_top_k * metric_weights) hr_sum = tf.reduce_sum(in_top_k * metric_weights)
hr_count = tf.reduce_sum(metric_weights) hr_count = tf.reduce_sum(metric_weights)
return hr_sum, hr_count return hr_sum, hr_count
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment