Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
e408e146
"vscode:/vscode.git/clone" did not exist on "0e59ecc71c88ee98a44b53eb13fb1909b14b6518"
Unverified
Commit
e408e146
authored
Aug 24, 2020
by
Tab Zhang
Committed by
GitHub
Aug 24, 2020
Browse files
Search space zoo example fix (#2801)
parent
593d2d20
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
5 additions
and
9 deletions
+5
-9
examples/nas/search_space_zoo/darts_example.py
examples/nas/search_space_zoo/darts_example.py
+1
-1
examples/nas/search_space_zoo/darts_stack_cells.py
examples/nas/search_space_zoo/darts_stack_cells.py
+2
-2
examples/nas/search_space_zoo/enas_macro_example.py
examples/nas/search_space_zoo/enas_macro_example.py
+0
-2
examples/nas/search_space_zoo/enas_micro_example.py
examples/nas/search_space_zoo/enas_micro_example.py
+1
-2
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
+1
-2
No files found.
examples/nas/search_space_zoo/darts_example.py
View file @
e408e146
...
...
@@ -14,7 +14,7 @@ from nni.nas.pytorch.darts import DartsTrainer
from
utils
import
accuracy
from
nni.nas.pytorch.search_space_zoo
import
DartsCell
from
darts_s
earch_space
import
DartsStackedCells
from
darts_s
tack_cells
import
DartsStackedCells
logger
=
logging
.
getLogger
(
'nni'
)
...
...
examples/nas/search_space_zoo/darts_stack_cells.py
View file @
e408e146
...
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import
torch.nn
as
nn
import
ops
from
nni.nas.pytorch.search_space_zoo.darts_ops
import
DropPath
class
DartsStackedCells
(
nn
.
Module
):
...
...
@@ -79,5 +79,5 @@ class DartsStackedCells(nn.Module):
def
drop_path_prob
(
self
,
p
):
for
module
in
self
.
modules
():
if
isinstance
(
module
,
ops
.
DropPath
):
if
isinstance
(
module
,
DropPath
):
module
.
p
=
p
examples/nas/search_space_zoo/enas_macro_example.py
View file @
e408e146
...
...
@@ -58,7 +58,6 @@ if __name__ == "__main__":
parser
=
ArgumentParser
(
"enas"
)
parser
.
add_argument
(
"--batch-size"
,
default
=
128
,
type
=
int
)
parser
.
add_argument
(
"--log-frequency"
,
default
=
10
,
type
=
int
)
# parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser
.
add_argument
(
"--epochs"
,
default
=
None
,
type
=
int
,
help
=
"Number of epochs (default: macro 310, micro 150)"
)
parser
.
add_argument
(
"--visualization"
,
default
=
False
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
...
...
@@ -71,7 +70,6 @@ if __name__ == "__main__":
criterion
=
nn
.
CrossEntropyLoss
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
0.05
,
momentum
=
0.9
,
weight_decay
=
1.0E-4
)
lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
CosineAnnealingLR
(
optimizer
,
T_max
=
num_epochs
,
eta_min
=
0.001
)
trainer
=
enas
.
EnasTrainer
(
model
,
loss
=
criterion
,
metrics
=
accuracy
,
...
...
examples/nas/search_space_zoo/enas_micro_example.py
View file @
e408e146
...
...
@@ -62,7 +62,7 @@ class MicroNetwork(nn.Module):
reduction
=
False
if
layer_id
in
pool_layers
:
c_cur
,
reduction
=
c_p
*
2
,
True
self
.
layers
.
append
(
ENASMicroLayer
(
self
.
layers
,
num_nodes
,
c_pp
,
c_p
,
c_cur
,
reduction
))
self
.
layers
.
append
(
ENASMicroLayer
(
num_nodes
,
c_pp
,
c_p
,
c_cur
,
reduction
))
if
reduction
:
c_pp
=
c_p
=
c_cur
c_pp
,
c_p
=
c_p
,
c_cur
...
...
@@ -98,7 +98,6 @@ if __name__ == "__main__":
parser
=
ArgumentParser
(
"enas"
)
parser
.
add_argument
(
"--batch-size"
,
default
=
128
,
type
=
int
)
parser
.
add_argument
(
"--log-frequency"
,
default
=
10
,
type
=
int
)
# parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser
.
add_argument
(
"--epochs"
,
default
=
None
,
type
=
int
,
help
=
"Number of epochs (default: macro 310, micro 150)"
)
parser
.
add_argument
(
"--visualization"
,
default
=
False
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
...
...
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
View file @
e408e146
...
...
@@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module):
"""
def
__init__
(
self
,
num_nodes
,
in_channels_pp
,
in_channels_p
,
out_channels
,
reduction
):
super
().
__init__
()
print
(
in_channels_pp
,
in_channels_p
,
out_channels
,
reduction
)
self
.
reduction
=
reduction
if
self
.
reduction
:
self
.
reduce0
=
FactorizedReduce
(
in_channels_pp
,
out_channels
,
affine
=
False
)
...
...
@@ -160,7 +159,7 @@ class ENASMacroLayer(mutables.MutableScope):
PoolBranch
(
'avg'
,
in_filters
,
out_filters
,
3
,
1
,
1
),
PoolBranch
(
'max'
,
in_filters
,
out_filters
,
3
,
1
,
1
)
])
if
prev_labels
>
0
:
if
prev_labels
:
self
.
skipconnect
=
mutables
.
InputChoice
(
choose_from
=
prev_labels
,
n_chosen
=
None
)
else
:
self
.
skipconnect
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment