Commit 6cfa81a1 authored by guptapriya's avatar guptapriya
Browse files

Remove metrics hack for dist strat

parent 51289259
...@@ -160,10 +160,6 @@ class MetricLayer(tf.keras.layers.Layer): ...@@ -160,10 +160,6 @@ class MetricLayer(tf.keras.layers.Layer):
def call(self, inputs): def call(self, inputs):
logits, targets = inputs[0], inputs[1] logits, targets = inputs[0], inputs[1]
# TODO(guptapriya): Remove this check when underlying issue to create
# metrics with dist strat in cross replica context is fixed.
if (tf.distribute.has_strategy() and
not tf.distribute.in_cross_replica_context()):
for mean, fn in self.metric_mean_fns: for mean, fn in self.metric_mean_fns:
m = mean(*fn(logits, targets)) m = mean(*fn(logits, targets))
self.add_metric(m) self.add_metric(m)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment