Unverified commit 72087f8a authored by Yuge Zhang, committed by GitHub
Browse files

Fix DARTS 2nd order (#4385)

parent 53e6bb3a
......@@ -214,7 +214,7 @@ class DartsTrainer(BaseOneShotTrainer):
# calculate unrolled loss on validation data
# keep gradients for model here for compute hessian
_, loss = self._logits_and_loss(val_X, val_y)
w_model, w_ctrl = tuple(self.model.parameters()), tuple([c.alpha for c in self.nas_modules])
w_model, w_ctrl = tuple(self.model.parameters()), tuple([c.alpha for _, c in self.nas_modules])
w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
......@@ -267,7 +267,7 @@ class DartsTrainer(BaseOneShotTrainer):
p += e * d
_, loss = self._logits_and_loss(trn_X, trn_y)
dalphas.append(torch.autograd.grad(loss, [c.alpha for c in self.nas_modules]))
dalphas.append(torch.autograd.grad(loss, [c.alpha for _, c in self.nas_modules]))
dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment