Unverified commit 34706ba0 authored by Jared T Nielsen, committed by GitHub

Allow for None gradients in GradientAccumulator. (#4372)

parent edf9ac11
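
Background: tf.GradientTape.gradient returns None for any source variable that is not connected to the differentiated value, so code consuming those gradients has to tolerate None entries. A minimal illustration of where such None gradients come from (the toy variables below are illustrative, not part of the commit):

import tensorflow as tf

x = tf.Variable(1.0)
y = tf.Variable(2.0)  # plays no role in the loss

with tf.GradientTape() as tape:
    loss = x * 3.0

# The gradient for y is None because y does not contribute to loss.
print(tape.gradient(loss, [x, y]))  # [<tf.Tensor: ... numpy=3.0>, None]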
@@ -217,7 +217,7 @@ class GradientAccumulator(object):
         """The accumulated gradients on the current replica."""
         if not self._gradients:
             raise ValueError("The accumulator should be called first to initialize the gradients")
-        return list(gradient.value() for gradient in self._gradients)
+        return list(gradient.value() if gradient is not None else gradient for gradient in self._gradients)
 
     def __call__(self, gradients):
         """Accumulates :obj:`gradients` on the current replica."""
@@ -231,6 +231,8 @@ class GradientAccumulator(object):
                     synchronization=tf.VariableSynchronization.ON_READ,
                     aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                 )
+                if gradient is not None
+                else gradient
                 for gradient in gradients
             ]
         )
@@ -238,7 +240,8 @@ class GradientAccumulator(object):
             raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients)))
 
         for accum_gradient, gradient in zip(self._gradients, gradients):
-            accum_gradient.assign_add(gradient)
+            if accum_gradient is not None and gradient is not None:
+                accum_gradient.assign_add(gradient)
 
         self._accum_steps.assign_add(1)
@@ -248,4 +251,5 @@ class GradientAccumulator(object):
             return
         self._accum_steps.assign(0)
         for gradient in self._gradients:
-            gradient.assign(tf.zeros_like(gradient))
+            if gradient is not None:
+                gradient.assign(tf.zeros_like(gradient))
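
Net effect: None gradients now flow through the accumulator unchanged instead of crashing it. They are skipped when the accumulation variables are created, when assign_add runs, and on reset, and the gradients property returns them as None. A minimal usage sketch; the import path and the toy training loop are assumptions, not taken from the commit:

import tensorflow as tf
from transformers.optimization_tf import GradientAccumulator  # import path assumed for this era of the library

used = tf.Variable(1.0)
unused = tf.Variable(2.0)  # never touches the loss, so its gradient is always None

accumulator = GradientAccumulator()
for _ in range(4):
    with tf.GradientTape() as tape:
        loss = used * 3.0
    grads = tape.gradient(loss, [used, unused])  # [3.0, None]
    accumulator(grads)  # before this patch, the None entry raised an error

print(accumulator.gradients)  # [12.0, None]: the None is passed through
accumulator.reset()  # None entries are skipped here as well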