Commit 6cfa81a1 authored by guptapriya's avatar guptapriya
Browse files

Remove metrics hack for dist strat

parent 51289259
...@@ -160,13 +160,9 @@ class MetricLayer(tf.keras.layers.Layer): ...@@ -160,13 +160,9 @@ class MetricLayer(tf.keras.layers.Layer):
def call(self, inputs): def call(self, inputs):
logits, targets = inputs[0], inputs[1] logits, targets = inputs[0], inputs[1]
# TODO(guptapriya): Remove this check when underlying issue to create for mean, fn in self.metric_mean_fns:
# metrics with dist strat in cross replica context is fixed. m = mean(*fn(logits, targets))
if (tf.distribute.has_strategy() and self.add_metric(m)
not tf.distribute.in_cross_replica_context()):
for mean, fn in self.metric_mean_fns:
m = mean(*fn(logits, targets))
self.add_metric(m)
return logits return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment