Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
192a807b
Unverified
Commit
192a807b
authored
Dec 14, 2020
by
QuanluZhang
Committed by
GitHub
Dec 14, 2020
Browse files
[Retiarii] refactor based on the new launch approach (#3185)
parent
80394047
Changes
24
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
47 additions
and
3 deletions
+47
-3
test/retiarii_test/mnasnet/base_mnasnet.py
test/retiarii_test/mnasnet/base_mnasnet.py
+3
-3
test/retiarii_test/mnasnet/mutator.py
test/retiarii_test/mnasnet/mutator.py
+0
-0
test/retiarii_test/mnasnet/test.py
test/retiarii_test/mnasnet/test.py
+44
-0
test/retiarii_test/simple_strategy.py
test/retiarii_test/simple_strategy.py
+0
-0
No files found.
test/
convert_tes
t/base_mnasnet.py
→
test/
retiarii_test/mnasne
t/base_mnasnet.py
View file @
192a807b
...
@@ -9,6 +9,7 @@ import sys
...
@@ -9,6 +9,7 @@ import sys
from
pathlib
import
Path
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
resolve
().
parents
[
2
]))
sys
.
path
.
append
(
str
(
Path
(
__file__
).
resolve
().
parents
[
2
]))
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
register_module
# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow.
# 1.0 - tensorflow.
...
@@ -109,7 +110,7 @@ def _get_depths(depths, alpha):
...
@@ -109,7 +110,7 @@ def _get_depths(depths, alpha):
rather than down. """
rather than down. """
return
[
_round_to_multiple_of
(
depth
*
alpha
,
8
)
for
depth
in
depths
]
return
[
_round_to_multiple_of
(
depth
*
alpha
,
8
)
for
depth
in
depths
]
@
register_module
()
class
MNASNet
(
nn
.
Module
):
class
MNASNet
(
nn
.
Module
):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model.
implements the B1 variant of the model.
...
@@ -126,8 +127,7 @@ class MNASNet(nn.Module):
...
@@ -126,8 +127,7 @@ class MNASNet(nn.Module):
def
__init__
(
self
,
alpha
,
depths
,
convops
,
kernel_sizes
,
num_layers
,
def
__init__
(
self
,
alpha
,
depths
,
convops
,
kernel_sizes
,
num_layers
,
skips
,
num_classes
=
1000
,
dropout
=
0.2
):
skips
,
num_classes
=
1000
,
dropout
=
0.2
):
super
(
MNASNet
,
self
).
__init__
(
alpha
,
depths
,
convops
,
kernel_sizes
,
num_layers
,
super
(
MNASNet
,
self
).
__init__
()
skips
,
num_classes
,
dropout
)
assert
alpha
>
0.0
assert
alpha
>
0.0
assert
len
(
depths
)
==
len
(
convops
)
==
len
(
kernel_sizes
)
==
len
(
num_layers
)
==
len
(
skips
)
==
7
assert
len
(
depths
)
==
len
(
convops
)
==
len
(
kernel_sizes
)
==
len
(
num_layers
)
==
len
(
skips
)
==
7
self
.
alpha
=
alpha
self
.
alpha
=
alpha
...
...
test/
convert_tes
t/mutator.py
→
test/
retiarii_test/mnasne
t/mutator.py
View file @
192a807b
File moved
test/
convert_tes
t/test.py
→
test/
retiarii_test/mnasne
t/test.py
View file @
192a807b
...
@@ -3,20 +3,11 @@ import sys
...
@@ -3,20 +3,11 @@ import sys
import
torch
import
torch
from
pathlib
import
Path
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
resolve
().
parents
[
2
]))
from
nni.retiarii.converter.graph_gen
import
convert_to_graph
from
nni.retiarii.converter.visualize
import
visualize_model
from
nni.retiarii.codegen.pytorch
import
model_to_pytorch_script
from
nni.retiarii
import
nn
from
nni.retiarii.trainer
import
PyTorchImageClassificationTrainer
from
nni.retiarii.trainer
import
PyTorchImageClassificationTrainer
from
nni.retiarii.utils
import
TraceClassArguments
from
base_mnasnet
import
MNASNet
from
base_mnasnet
import
MNASNet
from
nni.experiment
import
RetiariiExperiment
,
RetiariiEx
p
Config
from
nni.
retiarii.
experiment
import
RetiariiExperiment
,
RetiariiEx
e
Config
#from simple_strategy import SimpleStrategy
#from tpe_strategy import TPEStrategy
from
nni.retiarii.strategies
import
TPEStrategy
from
nni.retiarii.strategies
import
TPEStrategy
from
mutator
import
BlockMutator
from
mutator
import
BlockMutator
...
@@ -27,7 +18,6 @@ if __name__ == '__main__':
...
@@ -27,7 +18,6 @@ if __name__ == '__main__':
_DEFAULT_KERNEL_SIZES
=
[
3
,
3
,
5
,
5
,
3
,
5
,
3
]
_DEFAULT_KERNEL_SIZES
=
[
3
,
3
,
5
,
5
,
3
,
5
,
3
]
_DEFAULT_NUM_LAYERS
=
[
1
,
3
,
3
,
3
,
2
,
4
,
1
]
_DEFAULT_NUM_LAYERS
=
[
1
,
3
,
3
,
3
,
2
,
4
,
1
]
with
TraceClassArguments
()
as
tca
:
base_model
=
MNASNet
(
0.5
,
_DEFAULT_DEPTHS
,
_DEFAULT_CONVOPS
,
_DEFAULT_KERNEL_SIZES
,
base_model
=
MNASNet
(
0.5
,
_DEFAULT_DEPTHS
,
_DEFAULT_CONVOPS
,
_DEFAULT_KERNEL_SIZES
,
_DEFAULT_NUM_LAYERS
,
_DEFAULT_SKIPS
)
_DEFAULT_NUM_LAYERS
,
_DEFAULT_SKIPS
)
trainer
=
PyTorchImageClassificationTrainer
(
base_model
,
dataset_cls
=
"CIFAR10"
,
trainer
=
PyTorchImageClassificationTrainer
(
base_model
,
dataset_cls
=
"CIFAR10"
,
...
@@ -36,15 +26,6 @@ if __name__ == '__main__':
...
@@ -36,15 +26,6 @@ if __name__ == '__main__':
optimizer_kwargs
=
{
"lr"
:
1e-3
},
optimizer_kwargs
=
{
"lr"
:
1e-3
},
trainer_kwargs
=
{
"max_epochs"
:
1
})
trainer_kwargs
=
{
"max_epochs"
:
1
})
'''script_module = torch.jit.script(base_model)
model = convert_to_graph(script_module, base_model, tca.recorded_arguments)
code_script = model_to_pytorch_script(model)
print(code_script)
print("Model: ", model)
graph_ir = model._dump()
print(graph_ir)
visualize_model(graph_ir)'''
# new interface
# new interface
applied_mutators
=
[]
applied_mutators
=
[]
applied_mutators
.
append
(
BlockMutator
(
'mutable_0'
))
applied_mutators
.
append
(
BlockMutator
(
'mutable_0'
))
...
@@ -52,11 +33,12 @@ if __name__ == '__main__':
...
@@ -52,11 +33,12 @@ if __name__ == '__main__':
simple_startegy
=
TPEStrategy
()
simple_startegy
=
TPEStrategy
()
exp
=
RetiariiExperiment
(
base_model
,
trainer
,
applied_mutators
,
simple_startegy
,
tca
)
exp
=
RetiariiExperiment
(
base_model
,
trainer
,
applied_mutators
,
simple_startegy
)
exp_config
=
RetiariiEx
p
Config
.
create_template
(
'local'
)
exp_config
=
RetiariiEx
e
Config
(
'local'
)
exp_config
.
experiment_name
=
'mnasnet_search'
exp_config
.
experiment_name
=
'mnasnet_search'
exp_config
.
trial_concurrency
=
2
exp_config
.
trial_concurrency
=
2
exp_config
.
max_trial_number
=
10
exp_config
.
max_trial_number
=
10
exp_config
.
training_service
.
use_active_gpu
=
False
exp
.
run
(
exp_config
,
8081
,
debug
=
True
)
exp
.
run
(
exp_config
,
8081
,
debug
=
True
)
test/
convert
_test/simple_strategy.py
→
test/
retiarii
_test/simple_strategy.py
View file @
192a807b
File moved
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment