Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
e408e146
Unverified
Commit
e408e146
authored
Aug 24, 2020
by
Tab Zhang
Committed by
GitHub
Aug 24, 2020
Browse files
Search space zoo example fix (#2801)
parent
593d2d20
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
5 additions
and
9 deletions
+5
-9
examples/nas/search_space_zoo/darts_example.py
examples/nas/search_space_zoo/darts_example.py
+1
-1
examples/nas/search_space_zoo/darts_stack_cells.py
examples/nas/search_space_zoo/darts_stack_cells.py
+2
-2
examples/nas/search_space_zoo/enas_macro_example.py
examples/nas/search_space_zoo/enas_macro_example.py
+0
-2
examples/nas/search_space_zoo/enas_micro_example.py
examples/nas/search_space_zoo/enas_micro_example.py
+1
-2
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
+1
-2
No files found.
examples/nas/search_space_zoo/darts_example.py
View file @
e408e146
...
@@ -14,7 +14,7 @@ from nni.nas.pytorch.darts import DartsTrainer
...
@@ -14,7 +14,7 @@ from nni.nas.pytorch.darts import DartsTrainer
from
utils
import
accuracy
from
utils
import
accuracy
from
nni.nas.pytorch.search_space_zoo
import
DartsCell
from
nni.nas.pytorch.search_space_zoo
import
DartsCell
from
darts_s
earch_space
import
DartsStackedCells
from
darts_s
tack_cells
import
DartsStackedCells
logger
=
logging
.
getLogger
(
'nni'
)
logger
=
logging
.
getLogger
(
'nni'
)
...
...
examples/nas/search_space_zoo/darts_stack_cells.py
View file @
e408e146
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
torch.nn
as
nn
import
torch.nn
as
nn
import
ops
from
nni.nas.pytorch.search_space_zoo.darts_ops
import
DropPath
class
DartsStackedCells
(
nn
.
Module
):
class
DartsStackedCells
(
nn
.
Module
):
...
@@ -79,5 +79,5 @@ class DartsStackedCells(nn.Module):
...
@@ -79,5 +79,5 @@ class DartsStackedCells(nn.Module):
def
drop_path_prob
(
self
,
p
):
def
drop_path_prob
(
self
,
p
):
for
module
in
self
.
modules
():
for
module
in
self
.
modules
():
if
isinstance
(
module
,
ops
.
DropPath
):
if
isinstance
(
module
,
DropPath
):
module
.
p
=
p
module
.
p
=
p
examples/nas/search_space_zoo/enas_macro_example.py
View file @
e408e146
...
@@ -58,7 +58,6 @@ if __name__ == "__main__":
...
@@ -58,7 +58,6 @@ if __name__ == "__main__":
parser
=
ArgumentParser
(
"enas"
)
parser
=
ArgumentParser
(
"enas"
)
parser
.
add_argument
(
"--batch-size"
,
default
=
128
,
type
=
int
)
parser
.
add_argument
(
"--batch-size"
,
default
=
128
,
type
=
int
)
parser
.
add_argument
(
"--log-frequency"
,
default
=
10
,
type
=
int
)
parser
.
add_argument
(
"--log-frequency"
,
default
=
10
,
type
=
int
)
# parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser
.
add_argument
(
"--epochs"
,
default
=
None
,
type
=
int
,
help
=
"Number of epochs (default: macro 310, micro 150)"
)
parser
.
add_argument
(
"--epochs"
,
default
=
None
,
type
=
int
,
help
=
"Number of epochs (default: macro 310, micro 150)"
)
parser
.
add_argument
(
"--visualization"
,
default
=
False
,
action
=
"store_true"
)
parser
.
add_argument
(
"--visualization"
,
default
=
False
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -71,7 +70,6 @@ if __name__ == "__main__":
...
@@ -71,7 +70,6 @@ if __name__ == "__main__":
criterion
=
nn
.
CrossEntropyLoss
()
criterion
=
nn
.
CrossEntropyLoss
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
0.05
,
momentum
=
0.9
,
weight_decay
=
1.0E-4
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
0.05
,
momentum
=
0.9
,
weight_decay
=
1.0E-4
)
lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
CosineAnnealingLR
(
optimizer
,
T_max
=
num_epochs
,
eta_min
=
0.001
)
lr_scheduler
=
torch
.
optim
.
lr_scheduler
.
CosineAnnealingLR
(
optimizer
,
T_max
=
num_epochs
,
eta_min
=
0.001
)
trainer
=
enas
.
EnasTrainer
(
model
,
trainer
=
enas
.
EnasTrainer
(
model
,
loss
=
criterion
,
loss
=
criterion
,
metrics
=
accuracy
,
metrics
=
accuracy
,
...
...
examples/nas/search_space_zoo/enas_micro_example.py
View file @
e408e146
...
@@ -62,7 +62,7 @@ class MicroNetwork(nn.Module):
...
@@ -62,7 +62,7 @@ class MicroNetwork(nn.Module):
reduction
=
False
reduction
=
False
if
layer_id
in
pool_layers
:
if
layer_id
in
pool_layers
:
c_cur
,
reduction
=
c_p
*
2
,
True
c_cur
,
reduction
=
c_p
*
2
,
True
self
.
layers
.
append
(
ENASMicroLayer
(
self
.
layers
,
num_nodes
,
c_pp
,
c_p
,
c_cur
,
reduction
))
self
.
layers
.
append
(
ENASMicroLayer
(
num_nodes
,
c_pp
,
c_p
,
c_cur
,
reduction
))
if
reduction
:
if
reduction
:
c_pp
=
c_p
=
c_cur
c_pp
=
c_p
=
c_cur
c_pp
,
c_p
=
c_p
,
c_cur
c_pp
,
c_p
=
c_p
,
c_cur
...
@@ -98,7 +98,6 @@ if __name__ == "__main__":
...
@@ -98,7 +98,6 @@ if __name__ == "__main__":
parser
=
ArgumentParser
(
"enas"
)
parser
=
ArgumentParser
(
"enas"
)
parser
.
add_argument
(
"--batch-size"
,
default
=
128
,
type
=
int
)
parser
.
add_argument
(
"--batch-size"
,
default
=
128
,
type
=
int
)
parser
.
add_argument
(
"--log-frequency"
,
default
=
10
,
type
=
int
)
parser
.
add_argument
(
"--log-frequency"
,
default
=
10
,
type
=
int
)
# parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser
.
add_argument
(
"--epochs"
,
default
=
None
,
type
=
int
,
help
=
"Number of epochs (default: macro 310, micro 150)"
)
parser
.
add_argument
(
"--epochs"
,
default
=
None
,
type
=
int
,
help
=
"Number of epochs (default: macro 310, micro 150)"
)
parser
.
add_argument
(
"--visualization"
,
default
=
False
,
action
=
"store_true"
)
parser
.
add_argument
(
"--visualization"
,
default
=
False
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
src/sdk/pynni/nni/nas/pytorch/search_space_zoo/enas_cell.py
View file @
e408e146
...
@@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module):
...
@@ -79,7 +79,6 @@ class ENASMicroLayer(nn.Module):
"""
"""
def
__init__
(
self
,
num_nodes
,
in_channels_pp
,
in_channels_p
,
out_channels
,
reduction
):
def
__init__
(
self
,
num_nodes
,
in_channels_pp
,
in_channels_p
,
out_channels
,
reduction
):
super
().
__init__
()
super
().
__init__
()
print
(
in_channels_pp
,
in_channels_p
,
out_channels
,
reduction
)
self
.
reduction
=
reduction
self
.
reduction
=
reduction
if
self
.
reduction
:
if
self
.
reduction
:
self
.
reduce0
=
FactorizedReduce
(
in_channels_pp
,
out_channels
,
affine
=
False
)
self
.
reduce0
=
FactorizedReduce
(
in_channels_pp
,
out_channels
,
affine
=
False
)
...
@@ -160,7 +159,7 @@ class ENASMacroLayer(mutables.MutableScope):
...
@@ -160,7 +159,7 @@ class ENASMacroLayer(mutables.MutableScope):
PoolBranch
(
'avg'
,
in_filters
,
out_filters
,
3
,
1
,
1
),
PoolBranch
(
'avg'
,
in_filters
,
out_filters
,
3
,
1
,
1
),
PoolBranch
(
'max'
,
in_filters
,
out_filters
,
3
,
1
,
1
)
PoolBranch
(
'max'
,
in_filters
,
out_filters
,
3
,
1
,
1
)
])
])
if
prev_labels
>
0
:
if
prev_labels
:
self
.
skipconnect
=
mutables
.
InputChoice
(
choose_from
=
prev_labels
,
n_chosen
=
None
)
self
.
skipconnect
=
mutables
.
InputChoice
(
choose_from
=
prev_labels
,
n_chosen
=
None
)
else
:
else
:
self
.
skipconnect
=
None
self
.
skipconnect
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment