Commit e4c97f32 authored by Michael Carilli

Adding set_grads_to_None option to fp16_optimizer

parent e4af2d90
@@ -184,17 +184,28 @@ class FP16_Optimizer(object):
     def __setstate__(self, state):
         raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")

-    def zero_grad(self):
+    def zero_grad(self, set_grads_to_None=False):
         """
         Zero fp32 and fp16 parameter grads.
         """
         # In principle, only the .grad attributes of the model params need to be zeroed,
         # because gradients are copied into the FP32 master params.  However, we zero
         # all gradients owned by the optimizer, just to be safe:
-        self.optimizer.zero_grad()
+        for group in self.optimizer.param_groups:
+            for p in group['params']:
+                if set_grads_to_None:
+                    p.grad = None
+                else:
+                    if p.grad is not None:
+                        p.grad.detach_()
+                        p.grad.zero_()
         # Zero fp16 gradients owned by the model:
         for fp16_group in self.fp16_groups:
             for param in fp16_group:
-                if param.grad is not None:
-                    param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
-                    param.grad.zero_()
+                if set_grads_to_None:
+                    param.grad = None
+                else:
+                    if param.grad is not None:
+                        param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
+                        param.grad.zero_()
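
For context, a minimal usage sketch (not part of the commit): it assumes FP16_Optimizer's usual constructor with static_loss_scale and its backward(loss) loss-scaling helper as defined elsewhere in this file; the model, data, and hyperparameters are placeholders. Passing set_grads_to_None=True releases the .grad tensors instead of zeroing them in place, so the next backward pass allocates fresh gradient buffers rather than accumulating into zeroed ones.

    import torch
    # FP16_Optimizer is the class modified in this commit; the exact import path
    # depends on the package layout and is assumed here.
    # from apex.fp16_utils import FP16_Optimizer

    model = torch.nn.Linear(1024, 1024).cuda().half()
    optimizer = FP16_Optimizer(torch.optim.SGD(model.parameters(), lr=1e-3),
                               static_loss_scale=128.0)

    for step in range(10):
        inputs = torch.randn(32, 1024, device="cuda", dtype=torch.half)
        targets = torch.randn(32, 1024, device="cuda", dtype=torch.half)

        # Drop the .grad tensors entirely instead of filling them with zeros;
        # this skips the fill kernel and frees that memory until the next
        # backward pass recreates the gradients.
        optimizer.zero_grad(set_grads_to_None=True)

        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        optimizer.backward(loss)  # FP16_Optimizer's loss-scaling backward helper
        optimizer.step()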