OpenDAS / nni
"src/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "03f21987d429d5d25a0a867fc262c1fb6aa95e18"
Unverified commit 72087f8a, authored Dec 20, 2021 by Yuge Zhang, committed by GitHub on Dec 20, 2021
Fix DARTS 2nd order (#4385)

Parent: 53e6bb3a
Showing 1 changed file with 2 additions and 2 deletions:

nni/retiarii/oneshot/pytorch/darts.py (+2, -2)
nni/retiarii/oneshot/pytorch/darts.py @ 72087f8a
@@ -214,7 +214,7 @@ class DartsTrainer(BaseOneShotTrainer):
         # calculate unrolled loss on validation data
         # keep gradients for model here for compute hessian
         _, loss = self._logits_and_loss(val_X, val_y)
-        w_model, w_ctrl = tuple(self.model.parameters()), tuple([c.alpha for c in self.nas_modules])
+        w_model, w_ctrl = tuple(self.model.parameters()), tuple([c.alpha for _, c in self.nas_modules])
         w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
         d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
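The changed line matters because, as the fixed comprehension implies, self.nas_modules is iterated as (name, module) pairs; binding c to each element directly makes c a tuple, so c.alpha raises AttributeError. A minimal stand-alone sketch of the difference (SimpleNamespace is an illustrative stand-in for the choice module that carries the architecture weights, not the actual NNI class):

    from types import SimpleNamespace

    # Stand-in for a choice module holding architecture weights (alpha).
    choice = SimpleNamespace(alpha=[0.5, 0.5])
    nas_modules = [("op1", choice), ("op2", choice)]   # (name, module) pairs

    try:
        alphas = [c.alpha for c in nas_modules]        # old code: c is the whole tuple
    except AttributeError as exc:
        print(exc)                                     # 'tuple' object has no attribute 'alpha'

    alphas = [c.alpha for _, c in nas_modules]         # fixed: unpack (name, module) first
    print(alphas)                                      # [[0.5, 0.5], [0.5, 0.5]]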
@@ -267,7 +267,7 @@ class DartsTrainer(BaseOneShotTrainer):
                 p += e * d
             _, loss = self._logits_and_loss(trn_X, trn_y)
-            dalphas.append(torch.autograd.grad(loss, [c.alpha for c in self.nas_modules]))
+            dalphas.append(torch.autograd.grad(loss, [c.alpha for _, c in self.nas_modules]))
         dalpha_pos, dalpha_neg = dalphas  # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
         hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
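For reference, the hessian line in the second hunk is the central-difference estimate DARTS uses for its second-order term: the weights are perturbed to w ± eps * dw and the resulting change in the gradient of the training loss w.r.t. alpha is divided by 2 * eps. A small self-contained sketch (the toy loss and variable names below are illustrative assumptions, not NNI code) checks that estimate against an analytic value:

    import torch

    torch.manual_seed(0)
    a = torch.randn(3, requires_grad=True)  # plays the role of the architecture parameters (alpha)
    w = torch.randn(3)                      # plays the role of the model weights
    d = torch.randn(3)                      # plays the role of dw, the grad of the val loss w.r.t. w
    eps = 0.01 / d.norm()

    def grad_alpha(w_val):
        loss = (w_val * a).sum() ** 2       # toy training loss L_trn(w, alpha)
        return torch.autograd.grad(loss, a)[0]

    # Same formula as the `hessian` line: (dalpha_pos - dalpha_neg) / (2 * eps)
    fd = (grad_alpha(w + eps * d) - grad_alpha(w - eps * d)) / (2 * eps)

    # Analytic value for the toy loss: L = (w.a)^2, dL/da = 2 (w.a) w,
    # so the mixed second derivative contracted with d is 2 (d.a) w + 2 (w.a) d.
    exact = 2 * (d @ a) * w + 2 * (w @ a) * d
    print(torch.allclose(fd, exact.detach(), atol=1e-3))   # True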