Commit 48343d94 authored by Michael Carilli

Adding retain_graph=True option

parent 4212b3e9
@@ -422,7 +422,7 @@ class FP16_Optimizer(object):
         return retval
 
-    def backward(self, loss, update_master_grads=True):
+    def backward(self, loss, update_master_grads=True, retain_graph=False):
         """
         :attr:`backward` performs the following conceptual steps:
@@ -456,6 +456,7 @@ class FP16_Optimizer(object):
         Args:
             loss: The loss output by the user's model. loss may be either float or half (but see first Note above).
             update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`.
+            retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below).
 
         Example::
@@ -479,7 +480,7 @@ class FP16_Optimizer(object):
         # a loss scale that works. After you find a loss scale that works, do a final dummy
         # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
         # discarding the iteration, but probably wouldn't improve overall efficiency.
-        self.loss_scaler.backward(loss.float())
+        self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
         if update_master_grads:
             self.update_master_grads()
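A minimal usage sketch of the new option (not part of the diff; ``model``, ``criterion``, ``inputs``, and the targets are hypothetical stand-ins, and ``optimizer`` is assumed to be an :class:`FP16_Optimizer` instance). It follows the docstring's recommendation to pair ``retain_graph=True`` with ``update_master_grads=False`` when accumulating gradients from several backward passes::

    output = model(inputs)
    loss1 = criterion(output, target1)
    loss2 = criterion(output, target2)

    # First backward keeps the graph alive and defers the fp16->fp32 grad copy.
    optimizer.backward(loss1, update_master_grads=False, retain_graph=True)
    # Second backward may free the graph; fp16 grads accumulate across passes.
    optimizer.backward(loss2, update_master_grads=False)

    # One combined fp16->fp32 copy, then the usual step.
    optimizer.update_master_grads()
    optimizer.step()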
@@ -512,14 +513,23 @@ class FP16_Optimizer(object):
             List of lists (one list for each parameter group). The list for each parameter group
             is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.
         """
-        raise NotImplementedError("Currently not implemented, working on it...")
-        fp32_grads_each_group = []
         if self.overflow:
             print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. "
                   "Gradients are currently invalid (may be inf, nan, or stale). Returning None.")
             return None
         else:
-            return None
+            # The optimizer owns only references to master params.
+            master_grads_data = []
+            for param_group in self.optimizer.param_groups:
+                master_grads_this_group = []
+                for param in param_group['params']:
+                    if param.grad is not None:
+                        master_grads_this_group.append(param.grad.data)
+                    else:
+                        master_grads_this_group.append(None)
+                master_grads_data.append(master_grads_this_group)
+            return master_grads_data
 
     # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
     def _get_loss_scale(self):
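With ``inspect_master_grad_data`` now implemented, the master gradients can be examined between :attr:`backward` and :attr:`step`. A hedged sketch (again assuming ``optimizer`` is an :class:`FP16_Optimizer`; the printed statistic is just an example)::

    optimizer.backward(loss)
    master_grads = optimizer.inspect_master_grad_data()
    if master_grads is not None:  # None is returned while in an overflow state
        for group_idx, group in enumerate(master_grads):
            for param_idx, grad_data in enumerate(group):
                if grad_data is not None:
                    print(group_idx, param_idx, grad_data.norm())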
@@ -40,9 +40,9 @@ class LossScaler:
     def scale_gradient(self, module, grad_in, grad_out):
         return tuple(self.loss_scale * g for g in grad_in)
 
-    def backward(self, loss):
+    def backward(self, loss, retain_graph=False):
         scaled_loss = loss*self.loss_scale
-        scaled_loss.backward()
+        scaled_loss.backward(retain_graph=retain_graph)
 
 class DynamicLossScaler:
     """
@@ -127,9 +127,9 @@ class DynamicLossScaler:
     def scale_gradient(self, module, grad_in, grad_out):
         return tuple(self.loss_scale * g for g in grad_in)
 
-    def backward(self, loss):
+    def backward(self, loss, retain_graph=False):
         scaled_loss = loss*self.loss_scale
-        scaled_loss.backward()
+        scaled_loss.backward(retain_graph=retain_graph)
 
 ##############################################################
 # Example usage below here -- assuming it's in a separate file
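Both scalers change in the same way: the loss is multiplied by the scale before backward, and ``retain_graph`` is now forwarded so a second backward pass through the same graph remains legal. A standalone sketch of the underlying mechanics in plain PyTorch (independent of these classes; the tensors and the scale value are arbitrary)::

    import torch

    x = torch.randn(4, requires_grad=True)
    y = (x * x).sum()

    loss_scale = 128.0
    # Without retain_graph=True here, the second backward below would fail
    # with "Trying to backward through the graph a second time".
    (y * loss_scale).backward(retain_graph=True)
    # Second pass may free the graph; scaled grads accumulate in x.grad.
    (y * loss_scale).backward()
    # Unscale before use; the factor of two from accumulating two passes remains.
    x.grad.data.div_(loss_scale)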