Commit 3d2a7e7f authored by guptapriya's avatar guptapriya Committed by guptapriya
Browse files

Address code review comments

parent d0186041
......@@ -62,14 +62,12 @@ class MetricLayer(tf.keras.layers.Layer):
def __init__(self, params):
super(MetricLayer, self).__init__()
self.params = params
def build(self, input_shape):
self.metric = tf.keras.metrics.Mean(name=rconst.HR_METRIC_NAME)
def call(self, inputs):
logits, dup_mask = inputs
in_top_k, metric_weights = metric_fn(logits, dup_mask, self.params)
self.add_metric(self.metric(in_top_k, metric_weights))
self.add_metric(self.metric(in_top_k, sample_weight=metric_weights))
return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment