Commit 48343d94 authored by Michael Carilli

Adding retain_graph=True option

parent 4212b3e9
@@ -422,7 +422,7 @@ class FP16_Optimizer(object):
         return retval
-    def backward(self, loss, update_master_grads=True):
+    def backward(self, loss, update_master_grads=True, retain_graph=False):
         """
         :attr:`backward` performs the following conceptual steps:
@@ -456,6 +456,7 @@ class FP16_Optimizer(object):
         Args:
             loss: The loss output by the user's model. loss may be either float or half (but see first Note above).
             update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`.
+            retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below).
         Example::
@@ -479,7 +480,7 @@ class FP16_Optimizer(object):
         # a loss scale that works. After you find a loss scale that works, do a final dummy
         # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
         # discarding the iteration, but probably wouldn't improve overall efficiency.
-        self.loss_scaler.backward(loss.float())
+        self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
         if update_master_grads:
             self.update_master_grads()
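The pattern recommended by the new ``retain_graph`` docstring can be illustrated with a short sketch. This is not code from the commit; the model, loss functions, and variable names below are assumptions chosen for illustration. Two backward passes share one graph, the fp16->fp32 master-grad copy is deferred, and a single step is taken at the end.

# Hypothetical usage sketch, assuming `model` is a half-precision model wrapped
# by an FP16_Optimizer instance named `optimizer`, and `criterion_a`/`criterion_b`
# are two loss functions whose gradients should be accumulated in one iteration.
output = model(inputs)
loss_a = criterion_a(output, target)
loss_b = criterion_b(output, target)

# First backward pass keeps the graph alive and skips the fp16->fp32 copy.
optimizer.backward(loss_a, update_master_grads=False, retain_graph=True)
# Second backward pass may free the graph; the copy is still deferred.
optimizer.backward(loss_b, update_master_grads=False)

# Copy the accumulated fp16 grads to the fp32 master grads once, then step.
optimizer.update_master_grads()
optimizer.step()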
@@ -512,14 +513,23 @@ class FP16_Optimizer(object):
         List of lists (one list for each parameter group). The list for each parameter group
         is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.
         """
-        raise NotImplementedError("Currently not implemented, working on it...")
-        fp32_grads_each_group = []
         if self.overflow:
             print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. "
                   "Gradients are currently invalid (may be inf, nan, or stale). Returning None.")
             return None
         else:
-            return None
+            # The optimizer owns only references to master params.
+            master_grads_data = []
+            for param_group in self.optimizer.param_groups:
+                master_grads_this_group = []
+                for param in param_group['params']:
+                    if param.grad is not None:
+                        master_grads_this_group.append(param.grad.data)
+                    else:
+                        master_grads_this_group.append(None)
+                master_grads_data.append(master_grads_this_group)
+            return master_grads_data
     # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
     def _get_loss_scale(self):
...
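For reference, the list-of-lists returned by the reworked ``inspect_master_grad_data`` can be consumed as in the following sketch. This is illustrative only; ``optimizer`` is an assumed FP16_Optimizer instance and the printed quantity is arbitrary.

# Hypothetical inspection sketch: walk the per-group lists of fp32 master
# gradient tensors returned by inspect_master_grad_data().  The call returns
# None while the optimizer is in an overflow state.
master_grads = optimizer.inspect_master_grad_data()
if master_grads is not None:
    for group_idx, group_grads in enumerate(master_grads):
        for grad_data in group_grads:
            if grad_data is not None:
                print(group_idx, grad_data.norm())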
@@ -40,9 +40,9 @@ class LossScaler:
     def scale_gradient(self, module, grad_in, grad_out):
         return tuple(self.loss_scale * g for g in grad_in)
-    def backward(self, loss):
+    def backward(self, loss, retain_graph=False):
         scaled_loss = loss*self.loss_scale
-        scaled_loss.backward()
+        scaled_loss.backward(retain_graph=retain_graph)
 class DynamicLossScaler:
     """
@@ -127,9 +127,9 @@ class DynamicLossScaler:
     def scale_gradient(self, module, grad_in, grad_out):
         return tuple(self.loss_scale * g for g in grad_in)
-    def backward(self, loss):
+    def backward(self, loss, retain_graph=False):
         scaled_loss = loss*self.loss_scale
-        scaled_loss.backward()
+        scaled_loss.backward(retain_graph=retain_graph)
 ##############################################################
 # Example usage below here -- assuming it's in a separate file
...
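As a side note on what the loss-scaler change amounts to, here is a minimal standalone sketch (an assumption-laden rewrite for illustration, not code from the commit) of scaled backpropagation with ``retain_graph`` forwarded to autograd.

# Minimal sketch of a static loss scaler's backward step: multiply the loss by
# the scale factor before backpropagating, and forward retain_graph so the
# caller can run several backward passes through the same graph.
def scaled_backward(loss, loss_scale, retain_graph=False):
    scaled_loss = loss * loss_scale
    scaled_loss.backward(retain_graph=retain_graph)

# e.g. scaled_backward(loss_a, 128.0, retain_graph=True) followed by
#      scaled_backward(loss_b, 128.0)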