Unverified Commit 34706ba0 authored by Jared T Nielsen's avatar Jared T Nielsen Committed by GitHub
Browse files

Allow for None gradients in GradientAccumulator. (#4372)

parent edf9ac11
......@@ -217,7 +217,7 @@ class GradientAccumulator(object):
"""The accumulated gradients on the current replica."""
if not self._gradients:
raise ValueError("The accumulator should be called first to initialize the gradients")
return list(gradient.value() for gradient in self._gradients)
return list(gradient.value() if gradient is not None else gradient for gradient in self._gradients)
def __call__(self, gradients):
"""Accumulates :obj:`gradients` on the current replica."""
......@@ -231,6 +231,8 @@ class GradientAccumulator(object):
synchronization=tf.VariableSynchronization.ON_READ,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
)
if gradient is not None
else gradient
for gradient in gradients
]
)
......@@ -238,7 +240,8 @@ class GradientAccumulator(object):
raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients)))
for accum_gradient, gradient in zip(self._gradients, gradients):
accum_gradient.assign_add(gradient)
if accum_gradient is not None and gradient is not None:
accum_gradient.assign_add(gradient)
self._accum_steps.assign_add(1)
......@@ -248,4 +251,5 @@ class GradientAccumulator(object):
return
self._accum_steps.assign(0)
for gradient in self._gradients:
gradient.assign(tf.zeros_like(gradient))
if gradient is not None:
gradient.assign(tf.zeros_like(gradient))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment