Unverified Commit a9b87c9a authored by hroken's avatar hroken Committed by GitHub
Browse files

Bug-fixed for _compute_hessian in Pytorch NAS Darts (#3058)


Co-authored-by: default avatarhuwei <huw@abc.com>
parent 16dc45b1
...@@ -210,5 +210,5 @@ class DartsTrainer(Trainer): ...@@ -210,5 +210,5 @@ class DartsTrainer(Trainer):
dalphas.append(torch.autograd.grad(loss, self.mutator.parameters())) dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))
dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) } dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)] hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian return hessian
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment