Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
f1013390
"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "bc410d52cc02fbf6b4e5eab9eb4e2c0a0ae6aa47"
Unverified
Commit
f1013390
authored
Dec 25, 2020
by
Yuge Zhang
Committed by
GitHub
Dec 25, 2020
Browse files
Adding back the missing softmax in DARTS and support deduplication (#3224)
parent
9e5e0e3c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
3 deletions
+15
-3
nni/retiarii/trainer/pytorch/darts.py
nni/retiarii/trainer/pytorch/darts.py
+14
-3
nni/retiarii/trainer/pytorch/proxyless.py
nni/retiarii/trainer/pytorch/proxyless.py
+1
-0
No files found.
nni/retiarii/trainer/pytorch/darts.py
View file @
f1013390
...
...
@@ -6,6 +6,7 @@ import logging
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
..interface
import
BaseOneShotTrainer
from
.utils
import
AverageMeterGroup
,
replace_layer_choice
,
replace_input_choice
...
...
@@ -17,13 +18,14 @@ _logger = logging.getLogger(__name__)
class
DartsLayerChoice
(
nn
.
Module
):
def
__init__
(
self
,
layer_choice
):
super
(
DartsLayerChoice
,
self
).
__init__
()
self
.
name
=
layer_choice
.
key
self
.
op_choices
=
nn
.
ModuleDict
(
layer_choice
.
named_children
())
self
.
alpha
=
nn
.
Parameter
(
torch
.
randn
(
len
(
self
.
op_choices
))
*
1e-3
)
def
forward
(
self
,
*
args
,
**
kwargs
):
op_results
=
torch
.
stack
([
op
(
*
args
,
**
kwargs
)
for
op
in
self
.
op_choices
.
values
()])
alpha_shape
=
[
-
1
]
+
[
1
]
*
(
len
(
op_results
.
size
())
-
1
)
return
torch
.
sum
(
op_results
*
self
.
alpha
.
view
(
*
alpha_shape
),
0
)
return
torch
.
sum
(
op_results
*
F
.
softmax
(
self
.
alpha
,
-
1
)
.
view
(
*
alpha_shape
),
0
)
def
parameters
(
self
):
for
_
,
p
in
self
.
named_parameters
():
...
...
@@ -42,13 +44,14 @@ class DartsLayerChoice(nn.Module):
class
DartsInputChoice
(
nn
.
Module
):
def
__init__
(
self
,
input_choice
):
super
(
DartsInputChoice
,
self
).
__init__
()
self
.
name
=
input_choice
.
key
self
.
alpha
=
nn
.
Parameter
(
torch
.
randn
(
input_choice
.
n_candidates
)
*
1e-3
)
self
.
n_chosen
=
input_choice
.
n_chosen
or
1
def
forward
(
self
,
inputs
):
inputs
=
torch
.
stack
(
inputs
)
alpha_shape
=
[
-
1
]
+
[
1
]
*
(
len
(
inputs
.
size
())
-
1
)
return
torch
.
sum
(
inputs
*
self
.
alpha
.
view
(
*
alpha_shape
),
0
)
return
torch
.
sum
(
inputs
*
F
.
softmax
(
self
.
alpha
,
-
1
)
.
view
(
*
alpha_shape
),
0
)
def
parameters
(
self
):
for
_
,
p
in
self
.
named_parameters
():
...
...
@@ -123,7 +126,15 @@ class DartsTrainer(BaseOneShotTrainer):
module
.
to
(
self
.
device
)
self
.
model_optim
=
optimizer
self
.
ctrl_optim
=
torch
.
optim
.
Adam
([
m
.
alpha
for
_
,
m
in
self
.
nas_modules
],
arc_learning_rate
,
betas
=
(
0.5
,
0.999
),
# use the same architecture weight for modules with duplicated names
ctrl_params
=
{}
for
_
,
m
in
self
.
nas_modules
:
if
m
.
name
in
ctrl_params
:
assert
m
.
alpha
.
size
()
==
ctrl_params
[
m
.
name
].
size
(),
'Size of parameters with the same label should be same.'
m
.
alpha
=
ctrl_params
[
m
.
name
]
else
:
ctrl_params
[
m
.
name
]
=
m
.
alpha
self
.
ctrl_optim
=
torch
.
optim
.
Adam
(
list
(
ctrl_params
.
values
()),
arc_learning_rate
,
betas
=
(
0.5
,
0.999
),
weight_decay
=
1.0E-3
)
self
.
unrolled
=
unrolled
self
.
grad_clip
=
5.
...
...
nni/retiarii/trainer/pytorch/proxyless.py
View file @
f1013390
...
...
@@ -157,6 +157,7 @@ class ProxylessTrainer(BaseOneShotTrainer):
module
.
to
(
self
.
device
)
self
.
optimizer
=
optimizer
# we do not support deduplicate control parameters with same label (like DARTS) yet.
self
.
ctrl_optim
=
torch
.
optim
.
Adam
([
m
.
alpha
for
_
,
m
in
self
.
nas_modules
],
arc_learning_rate
,
weight_decay
=
0
,
betas
=
(
0
,
0.999
),
eps
=
1e-8
)
self
.
_init_dataloader
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment