Commit 6cfa81a1 authored by guptapriya's avatar guptapriya
Browse files

Remove metrics hack for dist strat

parent 51289259
......@@ -160,13 +160,9 @@ class MetricLayer(tf.keras.layers.Layer):
def call(self, inputs):
logits, targets = inputs[0], inputs[1]
# TODO(guptapriya): Remove this check when underlying issue to create
# metrics with dist strat in cross replica context is fixed.
if (tf.distribute.has_strategy() and
not tf.distribute.in_cross_replica_context()):
for mean, fn in self.metric_mean_fns:
m = mean(*fn(logits, targets))
self.add_metric(m)
for mean, fn in self.metric_mean_fns:
m = mean(*fn(logits, targets))
self.add_metric(m)
return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment