Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
4b2dcab3
"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "6958e2bf8efa01ece8ff07324ad82673ea9fec8f"
Unverified
Commit
4b2dcab3
authored
Dec 27, 2021
by
Yuge Zhang
Committed by
GitHub
Dec 27, 2021
Browse files
Minor fixes to SPOS and ProxylessNAS examples (#4420)
parent
64ea284f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
32 deletions
+25
-32
docs/en_US/NAS/SPOS.rst
docs/en_US/NAS/SPOS.rst
+1
-1
examples/nas/oneshot/proxylessnas/main.py
examples/nas/oneshot/proxylessnas/main.py
+19
-30
examples/nas/oneshot/spos/search.py
examples/nas/oneshot/spos/search.py
+5
-1
No files found.
docs/en_US/NAS/SPOS.rst
View file @
4b2dcab3
...
...
@@ -72,7 +72,7 @@ Step 3. Train for Evaluation
.. code-block:: bash
python
scratch
.py
python
evaluation
.py
By default, it will use ``architecture_final.json``. This architecture is provided by the official repo (converted into NNI format). You can use any architecture (e.g., the architecture found in step 2) with ``--fixed-arc`` option.
...
...
examples/nas/oneshot/proxylessnas/main.py
View file @
4b2dcab3
...
...
@@ -5,12 +5,11 @@ import sys
from
argparse
import
ArgumentParser
import
torch
from
nni.algorithms.nas.pytorch.proxylessnas
import
ProxylessNasTrainer
from
torchvision
import
transforms
from
nni.retiarii.fixed
import
fixed_arch
import
datasets
from
model
import
SearchMobileNet
from
nni.algorithms.nas.pytorch.proxylessnas
import
ProxylessNasTrainer
from
putils
import
LabelSmoothingLoss
,
accuracy
,
get_parameters
from
retrain
import
Retrain
...
...
@@ -40,7 +39,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--resize_scale"
,
default
=
0.08
,
type
=
float
)
parser
.
add_argument
(
"--distort_color"
,
default
=
'normal'
,
type
=
str
,
choices
=
[
'normal'
,
'strong'
,
'None'
])
# configurations for training mode
parser
.
add_argument
(
"--train_mode"
,
default
=
'search'
,
type
=
str
,
choices
=
[
'search_v1'
,
'search'
,
'retrain'
])
parser
.
add_argument
(
"--train_mode"
,
default
=
'search'
,
type
=
str
,
choices
=
[
'search'
,
'retrain'
])
# configurations for search
parser
.
add_argument
(
"--checkpoint_path"
,
default
=
'./search_mobile_net.pt'
,
type
=
str
)
parser
.
add_argument
(
"--arch_path"
,
default
=
'./arch_path.pt'
,
type
=
str
)
...
...
@@ -53,6 +52,17 @@ if __name__ == "__main__":
logger
.
error
(
'When --train_mode is retrain, --exported_arch_path must be specified.'
)
sys
.
exit
(
-
1
)
if
args
.
train_mode
==
'retrain'
:
assert
os
.
path
.
isfile
(
args
.
exported_arch_path
),
\
"exported_arch_path {} should be a file."
.
format
(
args
.
exported_arch_path
)
with
fixed_arch
(
args
.
exported_arch_path
):
model
=
SearchMobileNet
(
width_stages
=
[
int
(
i
)
for
i
in
args
.
width_stages
.
split
(
','
)],
n_cell_stages
=
[
int
(
i
)
for
i
in
args
.
n_cell_stages
.
split
(
','
)],
stride_stages
=
[
int
(
i
)
for
i
in
args
.
stride_stages
.
split
(
','
)],
n_classes
=
1000
,
dropout_rate
=
args
.
dropout_rate
,
bn_param
=
(
args
.
bn_momentum
,
args
.
bn_eps
))
else
:
model
=
SearchMobileNet
(
width_stages
=
[
int
(
i
)
for
i
in
args
.
width_stages
.
split
(
','
)],
n_cell_stages
=
[
int
(
i
)
for
i
in
args
.
n_cell_stages
.
split
(
','
)],
stride_stages
=
[
int
(
i
)
for
i
in
args
.
stride_stages
.
split
(
','
)],
...
...
@@ -125,28 +135,7 @@ if __name__ == "__main__":
trainer
.
fit
()
print
(
'Final architecture:'
,
trainer
.
export
())
json
.
dump
(
trainer
.
export
(),
open
(
'checkpoint.json'
,
'w'
))
elif
args
.
train_mode
==
'search_v1'
:
# this is architecture search
logger
.
info
(
'Creating ProxylessNasTrainer...'
)
trainer
=
ProxylessNasTrainer
(
model
,
model_optim
=
optimizer
,
train_loader
=
data_provider
.
train
,
valid_loader
=
data_provider
.
valid
,
device
=
device
,
warmup
=
args
.
warmup
,
ckpt_path
=
args
.
checkpoint_path
,
arch_path
=
args
.
arch_path
)
logger
.
info
(
'Start to train with ProxylessNasTrainer...'
)
trainer
.
train
()
logger
.
info
(
'Training done'
)
trainer
.
export
(
args
.
arch_path
)
logger
.
info
(
'Best architecture exported in %s'
,
args
.
arch_path
)
elif
args
.
train_mode
==
'retrain'
:
# this is retrain
from
nni.nas.pytorch.fixed
import
apply_fixed_architecture
assert
os
.
path
.
isfile
(
args
.
exported_arch_path
),
\
"exported_arch_path {} should be a file."
.
format
(
args
.
exported_arch_path
)
apply_fixed_architecture
(
model
,
args
.
exported_arch_path
)
trainer
=
Retrain
(
model
,
optimizer
,
device
,
data_provider
,
n_epochs
=
300
)
trainer
.
run
()
examples/nas/oneshot/spos/search.py
View file @
4b2dcab3
...
...
@@ -50,6 +50,9 @@ def _main(port):
transforms
.
ToTensor
(),
transforms
.
Normalize
([
0.49139968
,
0.48215827
,
0.44653124
],
[
0.24703233
,
0.24348505
,
0.26158768
])
]
# FIXME
# CIFAR10 is used here temporarily.
# Actually we should load weight from supernet and evaluate on imagenet.
train_dataset
=
serialize
(
CIFAR10
,
'data'
,
train
=
True
,
download
=
True
,
transform
=
transforms
.
Compose
(
transf
+
normalize
))
test_dataset
=
serialize
(
CIFAR10
,
'data'
,
train
=
False
,
transform
=
transforms
.
Compose
(
normalize
))
...
...
@@ -57,7 +60,8 @@ def _main(port):
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
64
),
max_epochs
=
2
,
gpus
=
1
)
simple_strategy
=
strategy
.
RegularizedEvolution
(
model_filter
=
LatencyFilter
(
threshold
=
100
,
predictor
=
base_predictor
),
population_size
=
2
,
cycles
=
2
)
simple_strategy
=
strategy
.
RegularizedEvolution
(
model_filter
=
LatencyFilter
(
threshold
=
100
,
predictor
=
base_predictor
),
sample_size
=
1
,
population_size
=
2
,
cycles
=
2
)
exp
=
RetiariiExperiment
(
base_model
,
trainer
,
strategy
=
simple_strategy
)
exp_config
=
RetiariiExeConfig
(
'local'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment