"src/vscode:/vscode.git/clone" did not exist on "ccf2c3118811ecb860cd95cc93910383a8a00063"
Unverified Commit 85b17833 authored by ngimel's avatar ngimel Committed by GitHub
Browse files

sgd supports zero-grad (#926)

parent 274cc063
......@@ -76,7 +76,8 @@ class FusedSGD(Optimizer):
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False,
wd_after_momentum=False,
materialize_master_grads=True):
materialize_master_grads=True,
set_grad_none=False):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
......@@ -94,6 +95,7 @@ class FusedSGD(Optimizer):
self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
self.set_grad_none = set_grad_none
if multi_tensor_applier.available:
import amp_C
......@@ -108,6 +110,14 @@ class FusedSGD(Optimizer):
for group in self.param_groups:
group.setdefault('nesterov', False)
def zero_grad(self):
    """Reset the gradients of every parameter managed by this optimizer.

    When ``self.set_grad_none`` is falsy, the work is delegated to the
    parent class's ``zero_grad``. Otherwise each parameter's ``.grad``
    attribute in every param group is rebound to ``None``.
    """
    if not self.set_grad_none:
        # Default path: defer to the base-class implementation.
        super(FusedSGD, self).zero_grad()
        return

    # Opt-in path: drop the gradient references entirely.
    for param_group in self.param_groups:
        for param in param_group['params']:
            param.grad = None
def get_momentums(self, params):
momentums = []
first_run = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment