Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
18962129
Unverified
Commit
18962129
authored
Apr 25, 2022
by
Yuge Zhang
Committed by
GitHub
Apr 25, 2022
Browse files
Add license header and typehints for NAS (#4774)
parent
8c2f717d
Changes
96
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
354 additions
and
270 deletions
+354
-270
nni/retiarii/evaluator/pytorch/lightning.py
nni/retiarii/evaluator/pytorch/lightning.py
+19
-15
nni/retiarii/execution/__init__.py
nni/retiarii/execution/__init__.py
+3
-0
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+1
-0
nni/retiarii/execution/benchmark.py
nni/retiarii/execution/benchmark.py
+9
-2
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+0
-1
nni/retiarii/execution/python.py
nni/retiarii/execution/python.py
+5
-1
nni/retiarii/execution/utils.py
nni/retiarii/execution/utils.py
+3
-0
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+27
-16
nni/retiarii/fixed.py
nni/retiarii/fixed.py
+3
-0
nni/retiarii/graph.py
nni/retiarii/graph.py
+29
-16
nni/retiarii/hub/pytorch/mobilenetv3.py
nni/retiarii/hub/pytorch/mobilenetv3.py
+4
-4
nni/retiarii/hub/pytorch/nasbench101.py
nni/retiarii/hub/pytorch/nasbench101.py
+2
-7
nni/retiarii/hub/pytorch/nasbench201.py
nni/retiarii/hub/pytorch/nasbench201.py
+6
-2
nni/retiarii/hub/pytorch/nasnet.py
nni/retiarii/hub/pytorch/nasnet.py
+41
-36
nni/retiarii/hub/pytorch/proxylessnas.py
nni/retiarii/hub/pytorch/proxylessnas.py
+29
-21
nni/retiarii/hub/pytorch/shufflenet.py
nni/retiarii/hub/pytorch/shufflenet.py
+16
-9
nni/retiarii/hub/pytorch/utils.py
nni/retiarii/hub/pytorch/utils.py
+5
-0
nni/retiarii/integration.py
nni/retiarii/integration.py
+15
-12
nni/retiarii/mutator.py
nni/retiarii/mutator.py
+6
-3
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+131
-125
No files found.
nni/retiarii/evaluator/pytorch/lightning.py
View file @
18962129
...
@@ -4,10 +4,11 @@
...
@@ -4,10 +4,11 @@
import
os
import
os
import
warnings
import
warnings
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
Union
,
Optional
,
List
,
Callable
from
typing
import
Dict
,
Union
,
Optional
,
List
,
Callable
,
Type
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
nn_functional
import
torch.optim
as
optim
import
torch.optim
as
optim
import
torchmetrics
import
torchmetrics
import
torch.utils.data
as
torch_data
import
torch.utils.data
as
torch_data
...
@@ -124,12 +125,12 @@ class Lightning(Evaluator):
...
@@ -124,12 +125,12 @@ class Lightning(Evaluator):
if
other
is
None
:
if
other
is
None
:
return
False
return
False
if
hasattr
(
self
,
"function"
)
and
hasattr
(
other
,
"function"
):
if
hasattr
(
self
,
"function"
)
and
hasattr
(
other
,
"function"
):
eq_func
=
(
self
.
function
==
other
.
function
)
eq_func
=
getattr
(
self
,
"
function
"
)
==
getattr
(
other
,
"
function
"
)
elif
not
(
hasattr
(
self
,
"function"
)
or
hasattr
(
other
,
"function"
)):
elif
not
(
hasattr
(
self
,
"function"
)
or
hasattr
(
other
,
"function"
)):
eq_func
=
True
eq_func
=
True
if
hasattr
(
self
,
"arguments"
)
and
hasattr
(
other
,
"arguments"
):
if
hasattr
(
self
,
"arguments"
)
and
hasattr
(
other
,
"arguments"
):
eq_args
=
(
self
.
arguments
==
other
.
arguments
)
eq_args
=
getattr
(
self
,
"
arguments
"
)
==
getattr
(
other
,
"
arguments
"
)
elif
not
(
hasattr
(
self
,
"arguments"
)
or
hasattr
(
other
,
"arguments"
)):
elif
not
(
hasattr
(
self
,
"arguments"
)
or
hasattr
(
other
,
"arguments"
)):
eq_args
=
True
eq_args
=
True
...
@@ -159,10 +160,13 @@ def _check_dataloader(dataloader):
...
@@ -159,10 +160,13 @@ def _check_dataloader(dataloader):
### The following are some commonly used Lightning modules ###
### The following are some commonly used Lightning modules ###
class
_SupervisedLearningModule
(
LightningModule
):
class
_SupervisedLearningModule
(
LightningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
,
metrics
:
Dict
[
str
,
torchmetrics
.
Metric
],
trainer
:
pl
.
Trainer
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
],
metrics
:
Dict
[
str
,
Type
[
torchmetrics
.
Metric
]],
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
,
optimizer
:
Type
[
optim
.
Optimizer
]
=
optim
.
Adam
,
export_onnx
:
Union
[
Path
,
str
,
bool
,
None
]
=
None
):
export_onnx
:
Union
[
Path
,
str
,
bool
,
None
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
save_hyperparameters
(
'criterion'
,
'optimizer'
,
'learning_rate'
,
'weight_decay'
)
self
.
save_hyperparameters
(
'criterion'
,
'optimizer'
,
'learning_rate'
,
'weight_decay'
)
...
@@ -214,7 +218,7 @@ class _SupervisedLearningModule(LightningModule):
...
@@ -214,7 +218,7 @@ class _SupervisedLearningModule(LightningModule):
self
.
log
(
'test_'
+
name
,
metric
(
y_hat
,
y
),
prog_bar
=
True
)
self
.
log
(
'test_'
+
name
,
metric
(
y_hat
,
y
),
prog_bar
=
True
)
def
configure_optimizers
(
self
):
def
configure_optimizers
(
self
):
return
self
.
optimizer
(
self
.
parameters
(),
lr
=
self
.
hparams
.
learning_rate
,
weight_decay
=
self
.
hparams
.
weight_decay
)
return
self
.
optimizer
(
self
.
parameters
(),
lr
=
self
.
hparams
.
learning_rate
,
weight_decay
=
self
.
hparams
.
weight_decay
)
# type: ignore
def
on_validation_epoch_end
(
self
):
def
on_validation_epoch_end
(
self
):
nni
.
report_intermediate_result
(
self
.
_get_validation_metrics
())
nni
.
report_intermediate_result
(
self
.
_get_validation_metrics
())
...
@@ -233,15 +237,15 @@ class _SupervisedLearningModule(LightningModule):
...
@@ -233,15 +237,15 @@ class _SupervisedLearningModule(LightningModule):
class
_AccuracyWithLogits
(
torchmetrics
.
Accuracy
):
class
_AccuracyWithLogits
(
torchmetrics
.
Accuracy
):
def
update
(
self
,
pred
,
target
):
def
update
(
self
,
pred
,
target
):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
return
super
().
update
(
nn
_
functional
.
softmax
(
pred
),
target
)
@
nni
.
trace
@
nni
.
trace
class
_ClassificationModule
(
_SupervisedLearningModule
):
class
_ClassificationModule
(
_SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
CrossEntropyLoss
,
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
CrossEntropyLoss
,
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
,
optimizer
:
Type
[
optim
.
Optimizer
]
=
optim
.
Adam
,
export_onnx
:
bool
=
True
):
export_onnx
:
bool
=
True
):
super
().
__init__
(
criterion
,
{
'acc'
:
_AccuracyWithLogits
},
super
().
__init__
(
criterion
,
{
'acc'
:
_AccuracyWithLogits
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
...
@@ -275,10 +279,10 @@ class Classification(Lightning):
...
@@ -275,10 +279,10 @@ class Classification(Lightning):
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
"""
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
CrossEntropyLoss
,
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
CrossEntropyLoss
,
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
,
optimizer
:
Type
[
optim
.
Optimizer
]
=
optim
.
Adam
,
train_dataloader
:
Optional
[
DataLoader
]
=
None
,
train_dataloader
:
Optional
[
DataLoader
]
=
None
,
val_dataloaders
:
Union
[
DataLoader
,
List
[
DataLoader
],
None
]
=
None
,
val_dataloaders
:
Union
[
DataLoader
,
List
[
DataLoader
],
None
]
=
None
,
export_onnx
:
bool
=
True
,
export_onnx
:
bool
=
True
,
...
@@ -291,10 +295,10 @@ class Classification(Lightning):
...
@@ -291,10 +295,10 @@ class Classification(Lightning):
@
nni
.
trace
@
nni
.
trace
class
_RegressionModule
(
_SupervisedLearningModule
):
class
_RegressionModule
(
_SupervisedLearningModule
):
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
MSELoss
,
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
MSELoss
,
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
,
optimizer
:
Type
[
optim
.
Optimizer
]
=
optim
.
Adam
,
export_onnx
:
bool
=
True
):
export_onnx
:
bool
=
True
):
super
().
__init__
(
criterion
,
{
'mse'
:
torchmetrics
.
MeanSquaredError
},
super
().
__init__
(
criterion
,
{
'mse'
:
torchmetrics
.
MeanSquaredError
},
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
learning_rate
=
learning_rate
,
weight_decay
=
weight_decay
,
optimizer
=
optimizer
,
...
@@ -328,10 +332,10 @@ class Regression(Lightning):
...
@@ -328,10 +332,10 @@ class Regression(Lightning):
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
"""
def
__init__
(
self
,
criterion
:
nn
.
Module
=
nn
.
MSELoss
,
def
__init__
(
self
,
criterion
:
Type
[
nn
.
Module
]
=
nn
.
MSELoss
,
learning_rate
:
float
=
0.001
,
learning_rate
:
float
=
0.001
,
weight_decay
:
float
=
0.
,
weight_decay
:
float
=
0.
,
optimizer
:
optim
.
Optimizer
=
optim
.
Adam
,
optimizer
:
Type
[
optim
.
Optimizer
]
=
optim
.
Adam
,
train_dataloader
:
Optional
[
DataLoader
]
=
None
,
train_dataloader
:
Optional
[
DataLoader
]
=
None
,
val_dataloaders
:
Union
[
DataLoader
,
List
[
DataLoader
],
None
]
=
None
,
val_dataloaders
:
Union
[
DataLoader
,
List
[
DataLoader
],
None
]
=
None
,
export_onnx
:
bool
=
True
,
export_onnx
:
bool
=
True
,
...
...
nni/retiarii/execution/__init__.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.api
import
*
from
.api
import
*
nni/retiarii/execution/base.py
View file @
18962129
...
@@ -129,6 +129,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -129,6 +129,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
@
classmethod
@
classmethod
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
mutation_summary
=
get_mutation_summary
(
model
)
mutation_summary
=
get_mutation_summary
(
model
)
assert
model
.
evaluator
is
not
None
,
'Model evaluator can not be None'
return
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
return
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
@
classmethod
@
classmethod
...
...
nni/retiarii/execution/benchmark.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
os
import
os
import
random
import
random
from
typing
import
Dict
,
Any
,
List
,
Optional
,
Union
,
Tuple
,
Callable
,
Iterable
from
typing
import
Dict
,
Any
,
List
,
Optional
,
Union
,
Tuple
,
Callable
,
Iterable
,
cast
from
..graph
import
Model
from
..graph
import
Model
from
..integration_api
import
receive_trial_parameters
from
..integration_api
import
receive_trial_parameters
...
@@ -39,6 +42,9 @@ class BenchmarkGraphData:
...
@@ -39,6 +42,9 @@ class BenchmarkGraphData:
def
load
(
data
)
->
'BenchmarkGraphData'
:
def
load
(
data
)
->
'BenchmarkGraphData'
:
return
BenchmarkGraphData
(
data
[
'mutation'
],
data
[
'benchmark'
],
data
[
'metric_name'
],
data
[
'db_path'
])
return
BenchmarkGraphData
(
data
[
'mutation'
],
data
[
'benchmark'
],
data
[
'metric_name'
],
data
[
'db_path'
])
def
__repr__
(
self
)
->
str
:
return
f
"BenchmarkGraphData(
{
self
.
mutation
}
,
{
self
.
benchmark
}
,
{
self
.
db_path
}
)"
class
BenchmarkExecutionEngine
(
BaseExecutionEngine
):
class
BenchmarkExecutionEngine
(
BaseExecutionEngine
):
"""
"""
...
@@ -67,6 +73,7 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
...
@@ -67,6 +73,7 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
@
classmethod
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
def
trial_execute_graph
(
cls
)
->
None
:
graph_data
=
BenchmarkGraphData
.
load
(
receive_trial_parameters
())
graph_data
=
BenchmarkGraphData
.
load
(
receive_trial_parameters
())
assert
graph_data
.
db_path
is
not
None
,
f
'Invalid graph data because db_path is None:
{
graph_data
}
'
os
.
environ
[
'NASBENCHMARK_DIR'
]
=
graph_data
.
db_path
os
.
environ
[
'NASBENCHMARK_DIR'
]
=
graph_data
.
db_path
final
,
intermediates
=
cls
.
query_in_benchmark
(
graph_data
)
final
,
intermediates
=
cls
.
query_in_benchmark
(
graph_data
)
...
@@ -89,7 +96,6 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
...
@@ -89,7 +96,6 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
arch
=
t
arch
=
t
if
arch
is
None
:
if
arch
is
None
:
raise
ValueError
(
f
'Cannot identify architecture from mutation dict:
{
graph_data
.
mutation
}
'
)
raise
ValueError
(
f
'Cannot identify architecture from mutation dict:
{
graph_data
.
mutation
}
'
)
print
(
arch
)
return
_convert_to_final_and_intermediates
(
return
_convert_to_final_and_intermediates
(
query_nb101_trial_stats
(
arch
,
108
,
include_intermediates
=
True
),
query_nb101_trial_stats
(
arch
,
108
,
include_intermediates
=
True
),
'valid_acc'
'valid_acc'
...
@@ -146,4 +152,5 @@ def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_
...
@@ -146,4 +152,5 @@ def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_
benchmark_result
=
random
.
choice
(
benchmark_result
)
benchmark_result
=
random
.
choice
(
benchmark_result
)
else
:
else
:
benchmark_result
=
benchmark_result
[
0
]
benchmark_result
=
benchmark_result
[
0
]
benchmark_result
=
cast
(
dict
,
benchmark_result
)
return
benchmark_result
[
metric_name
],
[
i
[
metric_name
]
for
i
in
benchmark_result
[
'intermediates'
]
if
i
[
metric_name
]
is
not
None
]
return
benchmark_result
[
metric_name
],
[
i
[
metric_name
]
for
i
in
benchmark_result
[
'intermediates'
]
if
i
[
metric_name
]
is
not
None
]
nni/retiarii/execution/cgo_engine.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
import
logging
import
logging
import
os
import
os
import
random
import
random
...
...
nni/retiarii/execution/python.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
Dict
,
Any
,
Type
from
typing
import
Dict
,
Any
,
Type
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -49,7 +52,8 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
...
@@ -49,7 +52,8 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
@
classmethod
@
classmethod
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
mutation
=
get_mutation_dict
(
model
)
mutation
=
get_mutation_dict
(
model
)
graph_data
=
PythonGraphData
(
model
.
python_class
,
model
.
python_init_params
,
mutation
,
model
.
evaluator
)
assert
model
.
evaluator
is
not
None
,
'Model evaluator is not available.'
graph_data
=
PythonGraphData
(
model
.
python_class
,
model
.
python_init_params
or
{},
mutation
,
model
.
evaluator
)
return
graph_data
return
graph_data
@
classmethod
@
classmethod
...
...
nni/retiarii/execution/utils.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
Any
,
List
from
typing
import
Any
,
List
from
..graph
import
Model
from
..graph
import
Model
...
...
nni/retiarii/experiment/pytorch.py
View file @
18962129
...
@@ -11,7 +11,7 @@ from dataclasses import dataclass
...
@@ -11,7 +11,7 @@ from dataclasses import dataclass
from
pathlib
import
Path
from
pathlib
import
Path
from
subprocess
import
Popen
from
subprocess
import
Popen
from
threading
import
Thread
from
threading
import
Thread
from
typing
import
Any
,
List
,
Optional
,
Union
from
typing
import
Any
,
List
,
Optional
,
Union
,
cast
import
colorama
import
colorama
import
psutil
import
psutil
...
@@ -23,6 +23,7 @@ from nni.experiment import Experiment, launcher, management, rest
...
@@ -23,6 +23,7 @@ from nni.experiment import Experiment, launcher, management, rest
from
nni.experiment.config
import
utils
from
nni.experiment.config
import
utils
from
nni.experiment.config.base
import
ConfigBase
from
nni.experiment.config.base
import
ConfigBase
from
nni.experiment.config.training_service
import
TrainingServiceConfig
from
nni.experiment.config.training_service
import
TrainingServiceConfig
from
nni.experiment.config.training_services
import
RemoteConfig
from
nni.experiment.pipe
import
Pipe
from
nni.experiment.pipe
import
Pipe
from
nni.tools.nnictl.command_utils
import
kill_command
from
nni.tools.nnictl.command_utils
import
kill_command
...
@@ -222,6 +223,7 @@ class RetiariiExperiment(Experiment):
...
@@ -222,6 +223,7 @@ class RetiariiExperiment(Experiment):
Examples
Examples
--------
--------
Multi-trial NAS:
Multi-trial NAS:
>>> base_model = Net()
>>> base_model = Net()
>>> search_strategy = strategy.Random()
>>> search_strategy = strategy.Random()
>>> model_evaluator = FunctionalEvaluator(evaluate_model)
>>> model_evaluator = FunctionalEvaluator(evaluate_model)
...
@@ -233,6 +235,7 @@ class RetiariiExperiment(Experiment):
...
@@ -233,6 +235,7 @@ class RetiariiExperiment(Experiment):
>>> exp.run(exp_config, 8081)
>>> exp.run(exp_config, 8081)
One-shot NAS:
One-shot NAS:
>>> base_model = Net()
>>> base_model = Net()
>>> search_strategy = strategy.DARTS()
>>> search_strategy = strategy.DARTS()
>>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
>>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
...
@@ -242,15 +245,16 @@ class RetiariiExperiment(Experiment):
...
@@ -242,15 +245,16 @@ class RetiariiExperiment(Experiment):
>>> exp.run(exp_config)
>>> exp.run(exp_config)
Export top models:
Export top models:
>>> for model_dict in exp.export_top_models(formatter='dict'):
>>> for model_dict in exp.export_top_models(formatter='dict'):
... print(model_dict)
... print(model_dict)
>>> with nni.retarii.fixed_arch(model_dict):
>>> with nni.retarii.fixed_arch(model_dict):
... final_model = Net()
... final_model = Net()
"""
"""
def
__init__
(
self
,
base_model
:
nn
.
Module
,
evaluator
:
Union
[
BaseOneShotTrainer
,
Evaluator
]
=
None
,
def
__init__
(
self
,
base_model
:
nn
.
Module
,
evaluator
:
Union
[
BaseOneShotTrainer
,
Evaluator
]
=
cast
(
Evaluator
,
None
)
,
applied_mutators
:
List
[
Mutator
]
=
None
,
strategy
:
BaseStrategy
=
None
,
applied_mutators
:
List
[
Mutator
]
=
cast
(
List
[
Mutator
],
None
)
,
strategy
:
BaseStrategy
=
cast
(
BaseStrategy
,
None
)
,
trainer
:
BaseOneShotTrainer
=
None
):
trainer
:
BaseOneShotTrainer
=
cast
(
BaseOneShotTrainer
,
None
)
)
:
if
trainer
is
not
None
:
if
trainer
is
not
None
:
warnings
.
warn
(
'Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
warnings
.
warn
(
'Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
'Please consider specifying it as a positional argument, or use `evaluator`.'
,
DeprecationWarning
)
'Please consider specifying it as a positional argument, or use `evaluator`.'
,
DeprecationWarning
)
...
@@ -260,21 +264,22 @@ class RetiariiExperiment(Experiment):
...
@@ -260,21 +264,22 @@ class RetiariiExperiment(Experiment):
raise
ValueError
(
'Evaluator should not be none.'
)
raise
ValueError
(
'Evaluator should not be none.'
)
# TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
# TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
self
.
config
:
RetiariiExeConfig
=
None
self
.
config
:
RetiariiExeConfig
=
cast
(
RetiariiExeConfig
,
None
)
self
.
port
:
Optional
[
int
]
=
None
self
.
port
:
Optional
[
int
]
=
None
self
.
base_model
=
base_model
self
.
base_model
=
base_model
self
.
evaluator
:
Evaluator
=
evaluator
self
.
evaluator
:
Union
[
Evaluator
,
BaseOneShotTrainer
]
=
evaluator
self
.
applied_mutators
=
applied_mutators
self
.
applied_mutators
=
applied_mutators
self
.
strategy
=
strategy
self
.
strategy
=
strategy
# FIXME: this is only a workaround
from
nni.retiarii.oneshot.pytorch.strategy
import
OneShotStrategy
from
nni.retiarii.oneshot.pytorch.strategy
import
OneShotStrategy
if
not
isinstance
(
strategy
,
OneShotStrategy
):
if
not
isinstance
(
strategy
,
OneShotStrategy
):
self
.
_dispatcher
=
RetiariiAdvisor
()
self
.
_dispatcher
=
RetiariiAdvisor
()
self
.
_dispatcher_thread
:
Optional
[
Thread
]
=
None
else
:
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_dispatcher
=
cast
(
RetiariiAdvisor
,
None
)
self
.
_pipe
:
Optional
[
Pipe
]
=
None
self
.
_dispatcher_thread
:
Optional
[
Thread
]
=
None
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_pipe
:
Optional
[
Pipe
]
=
None
self
.
url_prefix
=
None
self
.
url_prefix
=
None
...
@@ -325,7 +330,7 @@ class RetiariiExperiment(Experiment):
...
@@ -325,7 +330,7 @@ class RetiariiExperiment(Experiment):
assert
self
.
config
.
training_service
.
platform
==
'remote'
,
\
assert
self
.
config
.
training_service
.
platform
==
'remote'
,
\
"CGO execution engine currently only supports remote training service"
"CGO execution engine currently only supports remote training service"
assert
self
.
config
.
batch_waiting_time
is
not
None
assert
self
.
config
.
batch_waiting_time
is
not
None
and
self
.
config
.
max_concurrency_cgo
is
not
None
devices
=
self
.
_construct_devices
()
devices
=
self
.
_construct_devices
()
engine
=
CGOExecutionEngine
(
devices
,
engine
=
CGOExecutionEngine
(
devices
,
max_concurrency
=
self
.
config
.
max_concurrency_cgo
,
max_concurrency
=
self
.
config
.
max_concurrency_cgo
,
...
@@ -335,7 +340,10 @@ class RetiariiExperiment(Experiment):
...
@@ -335,7 +340,10 @@ class RetiariiExperiment(Experiment):
engine
=
PurePythonExecutionEngine
()
engine
=
PurePythonExecutionEngine
()
elif
self
.
config
.
execution_engine
==
'benchmark'
:
elif
self
.
config
.
execution_engine
==
'benchmark'
:
from
..execution.benchmark
import
BenchmarkExecutionEngine
from
..execution.benchmark
import
BenchmarkExecutionEngine
assert
self
.
config
.
benchmark
is
not
None
,
'"benchmark" must be set when benchmark execution engine is used.'
engine
=
BenchmarkExecutionEngine
(
self
.
config
.
benchmark
)
engine
=
BenchmarkExecutionEngine
(
self
.
config
.
benchmark
)
else
:
raise
ValueError
(
f
'Unsupported engine type:
{
self
.
config
.
execution_engine
}
'
)
set_execution_engine
(
engine
)
set_execution_engine
(
engine
)
self
.
id
=
management
.
generate_experiment_id
()
self
.
id
=
management
.
generate_experiment_id
()
...
@@ -377,9 +385,10 @@ class RetiariiExperiment(Experiment):
...
@@ -377,9 +385,10 @@ class RetiariiExperiment(Experiment):
def
_construct_devices
(
self
):
def
_construct_devices
(
self
):
devices
=
[]
devices
=
[]
if
hasattr
(
self
.
config
.
training_service
,
'machine_list'
):
if
hasattr
(
self
.
config
.
training_service
,
'machine_list'
):
for
machine
in
self
.
config
.
training_service
.
machine_list
:
for
machine
in
cast
(
RemoteConfig
,
self
.
config
.
training_service
)
.
machine_list
:
assert
machine
.
gpu_indices
is
not
None
,
\
assert
machine
.
gpu_indices
is
not
None
,
\
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert
isinstance
(
machine
.
gpu_indices
,
list
),
'gpu_indices must be a list'
for
gpu_idx
in
machine
.
gpu_indices
:
for
gpu_idx
in
machine
.
gpu_indices
:
devices
.
append
(
GPUDevice
(
machine
.
host
,
gpu_idx
))
devices
.
append
(
GPUDevice
(
machine
.
host
,
gpu_idx
))
return
devices
return
devices
...
@@ -387,7 +396,7 @@ class RetiariiExperiment(Experiment):
...
@@ -387,7 +396,7 @@ class RetiariiExperiment(Experiment):
def
_create_dispatcher
(
self
):
def
_create_dispatcher
(
self
):
return
self
.
_dispatcher
return
self
.
_dispatcher
def
run
(
self
,
config
:
RetiariiExeConfig
=
None
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
str
:
def
run
(
self
,
config
:
Optional
[
RetiariiExeConfig
]
=
None
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
"""
"""
Run the experiment.
Run the experiment.
This function will block until experiment finish or error.
This function will block until experiment finish or error.
...
@@ -420,6 +429,7 @@ class RetiariiExperiment(Experiment):
...
@@ -420,6 +429,7 @@ class RetiariiExperiment(Experiment):
This function will block until experiment finish or error.
This function will block until experiment finish or error.
Return `True` when experiment done; or return `False` when experiment failed.
Return `True` when experiment done; or return `False` when experiment failed.
"""
"""
assert
self
.
_proc
is
not
None
try
:
try
:
while
True
:
while
True
:
time
.
sleep
(
10
)
time
.
sleep
(
10
)
...
@@ -437,6 +447,7 @@ class RetiariiExperiment(Experiment):
...
@@ -437,6 +447,7 @@ class RetiariiExperiment(Experiment):
_logger
.
warning
(
'KeyboardInterrupt detected'
)
_logger
.
warning
(
'KeyboardInterrupt detected'
)
finally
:
finally
:
self
.
stop
()
self
.
stop
()
raise
RuntimeError
(
'Check experiment status failed.'
)
def
stop
(
self
)
->
None
:
def
stop
(
self
)
->
None
:
"""
"""
...
@@ -466,11 +477,11 @@ class RetiariiExperiment(Experiment):
...
@@ -466,11 +477,11 @@ class RetiariiExperiment(Experiment):
if
self
.
_pipe
is
not
None
:
if
self
.
_pipe
is
not
None
:
self
.
_pipe
.
close
()
self
.
_pipe
.
close
()
self
.
id
=
None
self
.
id
=
cast
(
str
,
None
)
self
.
port
=
None
self
.
port
=
cast
(
int
,
None
)
self
.
_proc
=
None
self
.
_proc
=
None
self
.
_pipe
=
None
self
.
_pipe
=
None
self
.
_dispatcher
=
None
self
.
_dispatcher
=
cast
(
RetiariiAdvisor
,
None
)
self
.
_dispatcher_thread
=
None
self
.
_dispatcher_thread
=
None
_logger
.
info
(
'Experiment stopped'
)
_logger
.
info
(
'Experiment stopped'
)
...
...
nni/retiarii/fixed.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
json
import
json
import
logging
import
logging
from
pathlib
import
Path
from
pathlib
import
Path
...
...
nni/retiarii/graph.py
View file @
18962129
...
@@ -5,10 +5,16 @@
...
@@ -5,10 +5,16 @@
Model representation.
Model representation.
"""
"""
from
__future__
import
annotations
import
abc
import
abc
import
json
import
json
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Type
,
Union
,
overload
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
if
TYPE_CHECKING
:
from
.mutator
import
Mutator
from
.operation
import
Cell
,
Operation
,
_IOPseudoOperation
from
.operation
import
Cell
,
Operation
,
_IOPseudoOperation
from
.utils
import
uid
from
.utils
import
uid
...
@@ -63,7 +69,7 @@ class Evaluator(abc.ABC):
...
@@ -63,7 +69,7 @@ class Evaluator(abc.ABC):
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
_execute
(
self
,
model_cls
:
type
)
->
Any
:
def
_execute
(
self
,
model_cls
:
Union
[
Callable
[[],
Any
],
Any
]
)
->
Any
:
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
...
@@ -203,7 +209,7 @@ class Model:
...
@@ -203,7 +209,7 @@ class Model:
matched_nodes
.
extend
(
nodes
)
matched_nodes
.
extend
(
nodes
)
return
matched_nodes
return
matched_nodes
def
get_node_by_name
(
self
,
node_name
:
str
)
->
'Node'
:
def
get_node_by_name
(
self
,
node_name
:
str
)
->
'Node'
|
None
:
"""
"""
Traverse all the nodes to find the matched node with the given name.
Traverse all the nodes to find the matched node with the given name.
"""
"""
...
@@ -217,7 +223,7 @@ class Model:
...
@@ -217,7 +223,7 @@ class Model:
else
:
else
:
return
None
return
None
def
get_node_by_python_name
(
self
,
python_name
:
str
)
->
'Node'
:
def
get_node_by_python_name
(
self
,
python_name
:
str
)
->
Optional
[
'Node'
]
:
"""
"""
Traverse all the nodes to find the matched node with the given python_name.
Traverse all the nodes to find the matched node with the given python_name.
"""
"""
...
@@ -297,7 +303,7 @@ class Graph:
...
@@ -297,7 +303,7 @@ class Graph:
The name of torch.nn.Module, should have one-to-one mapping with items in python model.
The name of torch.nn.Module, should have one-to-one mapping with items in python model.
"""
"""
def
__init__
(
self
,
model
:
Model
,
graph_id
:
int
,
name
:
str
=
None
,
_internal
:
bool
=
False
):
def
__init__
(
self
,
model
:
Model
,
graph_id
:
int
,
name
:
str
=
cast
(
str
,
None
)
,
_internal
:
bool
=
False
):
assert
_internal
,
'`Graph()` is private'
assert
_internal
,
'`Graph()` is private'
self
.
model
:
Model
=
model
self
.
model
:
Model
=
model
...
@@ -338,9 +344,9 @@ class Graph:
...
@@ -338,9 +344,9 @@ class Graph:
@
overload
@
overload
def
add_node
(
self
,
name
:
str
,
operation
:
Operation
)
->
'Node'
:
...
def
add_node
(
self
,
name
:
str
,
operation
:
Operation
)
->
'Node'
:
...
@
overload
@
overload
def
add_node
(
self
,
name
:
str
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
)
->
'Node'
:
...
def
add_node
(
self
,
name
:
str
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
cast
(
Dict
[
str
,
Any
],
None
)
)
->
'Node'
:
...
def
add_node
(
self
,
name
,
operation_or_type
,
parameters
=
None
):
def
add_node
(
self
,
name
,
operation_or_type
,
parameters
=
None
):
# type: ignore
if
isinstance
(
operation_or_type
,
Operation
):
if
isinstance
(
operation_or_type
,
Operation
):
op
=
operation_or_type
op
=
operation_or_type
else
:
else
:
...
@@ -350,9 +356,10 @@ class Graph:
...
@@ -350,9 +356,10 @@ class Graph:
@
overload
@
overload
def
insert_node_on_edge
(
self
,
edge
:
'Edge'
,
name
:
str
,
operation
:
Operation
)
->
'Node'
:
...
def
insert_node_on_edge
(
self
,
edge
:
'Edge'
,
name
:
str
,
operation
:
Operation
)
->
'Node'
:
...
@
overload
@
overload
def
insert_node_on_edge
(
self
,
edge
:
'Edge'
,
name
:
str
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
)
->
'Node'
:
...
def
insert_node_on_edge
(
self
,
edge
:
'Edge'
,
name
:
str
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
cast
(
Dict
[
str
,
Any
],
None
))
->
'Node'
:
...
def
insert_node_on_edge
(
self
,
edge
,
name
,
operation_or_type
,
parameters
=
None
)
->
'Node'
:
def
insert_node_on_edge
(
self
,
edge
,
name
,
operation_or_type
,
parameters
=
None
)
->
'Node'
:
# type: ignore
if
isinstance
(
operation_or_type
,
Operation
):
if
isinstance
(
operation_or_type
,
Operation
):
op
=
operation_or_type
op
=
operation_or_type
else
:
else
:
...
@@ -405,7 +412,7 @@ class Graph:
...
@@ -405,7 +412,7 @@ class Graph:
def
get_nodes_by_name
(
self
,
name
:
str
)
->
List
[
'Node'
]:
def
get_nodes_by_name
(
self
,
name
:
str
)
->
List
[
'Node'
]:
return
[
node
for
node
in
self
.
hidden_nodes
if
node
.
name
==
name
]
return
[
node
for
node
in
self
.
hidden_nodes
if
node
.
name
==
name
]
def
get_nodes_by_python_name
(
self
,
python_name
:
str
)
->
Optional
[
'Node'
]:
def
get_nodes_by_python_name
(
self
,
python_name
:
str
)
->
List
[
'Node'
]:
return
[
node
for
node
in
self
.
nodes
if
node
.
python_name
==
python_name
]
return
[
node
for
node
in
self
.
nodes
if
node
.
python_name
==
python_name
]
def
topo_sort
(
self
)
->
List
[
'Node'
]:
def
topo_sort
(
self
)
->
List
[
'Node'
]:
...
@@ -594,7 +601,7 @@ class Node:
...
@@ -594,7 +601,7 @@ class Node:
return
sorted
(
set
(
edge
.
tail
for
edge
in
self
.
outgoing_edges
),
key
=
(
lambda
node
:
node
.
id
))
return
sorted
(
set
(
edge
.
tail
for
edge
in
self
.
outgoing_edges
),
key
=
(
lambda
node
:
node
.
id
))
@
property
@
property
def
successor_slots
(
self
)
->
Lis
t
[
Tuple
[
'Node'
,
Union
[
int
,
None
]]]:
def
successor_slots
(
self
)
->
Se
t
[
Tuple
[
'Node'
,
Union
[
int
,
None
]]]:
return
set
((
edge
.
tail
,
edge
.
tail_slot
)
for
edge
in
self
.
outgoing_edges
)
return
set
((
edge
.
tail
,
edge
.
tail_slot
)
for
edge
in
self
.
outgoing_edges
)
@
property
@
property
...
@@ -610,19 +617,19 @@ class Node:
...
@@ -610,19 +617,19 @@ class Node:
assert
isinstance
(
self
.
operation
,
Cell
)
assert
isinstance
(
self
.
operation
,
Cell
)
return
self
.
graph
.
model
.
graphs
[
self
.
operation
.
parameters
[
'cell'
]]
return
self
.
graph
.
model
.
graphs
[
self
.
operation
.
parameters
[
'cell'
]]
def
update_label
(
self
,
label
:
str
)
->
None
:
def
update_label
(
self
,
label
:
Optional
[
str
]
)
->
None
:
self
.
label
=
label
self
.
label
=
label
@
overload
@
overload
def
update_operation
(
self
,
operation
:
Operation
)
->
None
:
...
def
update_operation
(
self
,
operation
:
Operation
)
->
None
:
...
@
overload
@
overload
def
update_operation
(
self
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
)
->
None
:
...
def
update_operation
(
self
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
cast
(
Dict
[
str
,
Any
],
None
)
)
->
None
:
...
def
update_operation
(
self
,
operation_or_type
,
parameters
=
None
):
def
update_operation
(
self
,
operation_or_type
,
parameters
=
None
):
# type: ignore
if
isinstance
(
operation_or_type
,
Operation
):
if
isinstance
(
operation_or_type
,
Operation
):
self
.
operation
=
operation_or_type
self
.
operation
=
operation_or_type
else
:
else
:
self
.
operation
=
Operation
.
new
(
operation_or_type
,
parameters
)
self
.
operation
=
Operation
.
new
(
operation_or_type
,
cast
(
dict
,
parameters
)
)
# mutation
# mutation
def
remove
(
self
)
->
None
:
def
remove
(
self
)
->
None
:
...
@@ -663,7 +670,13 @@ class Node:
...
@@ -663,7 +670,13 @@ class Node:
return
node
return
node
def
_dump
(
self
)
->
Any
:
def
_dump
(
self
)
->
Any
:
ret
=
{
'operation'
:
{
'type'
:
self
.
operation
.
type
,
'parameters'
:
self
.
operation
.
parameters
,
'attributes'
:
self
.
operation
.
attributes
}}
ret
:
Dict
[
str
,
Any
]
=
{
'operation'
:
{
'type'
:
self
.
operation
.
type
,
'parameters'
:
self
.
operation
.
parameters
,
'attributes'
:
self
.
operation
.
attributes
}
}
if
isinstance
(
self
.
operation
,
Cell
):
if
isinstance
(
self
.
operation
,
Cell
):
ret
[
'operation'
][
'cell_name'
]
=
self
.
operation
.
cell_name
ret
[
'operation'
][
'cell_name'
]
=
self
.
operation
.
cell_name
if
self
.
label
is
not
None
:
if
self
.
label
is
not
None
:
...
...
nni/retiarii/hub/pytorch/mobilenetv3.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
typing
import
Tuple
,
Optional
,
Callable
from
typing
import
Tuple
,
Optional
,
Callable
,
cast
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
model_wrapper
from
nni.retiarii
import
model_wrapper
...
@@ -75,10 +75,10 @@ class MobileNetV3Space(nn.Module):
...
@@ -75,10 +75,10 @@ class MobileNetV3Space(nn.Module):
bn_momentum
:
float
=
0.1
):
bn_momentum
:
float
=
0.1
):
super
().
__init__
()
super
().
__init__
()
self
.
widths
=
[
self
.
widths
=
cast
(
nn
.
ChoiceOf
[
int
],
[
nn
.
ValueChoice
([
make_divisible
(
base_width
*
mult
,
8
)
for
mult
in
width_multipliers
],
label
=
f
'width_
{
i
}
'
)
nn
.
ValueChoice
([
make_divisible
(
base_width
*
mult
,
8
)
for
mult
in
width_multipliers
],
label
=
f
'width_
{
i
}
'
)
for
i
,
base_width
in
enumerate
(
base_widths
)
for
i
,
base_width
in
enumerate
(
base_widths
)
]
]
)
self
.
expand_ratios
=
expand_ratios
self
.
expand_ratios
=
expand_ratios
blocks
=
[
blocks
=
[
...
@@ -115,7 +115,7 @@ class MobileNetV3Space(nn.Module):
...
@@ -115,7 +115,7 @@ class MobileNetV3Space(nn.Module):
self
.
classifier
=
nn
.
Sequential
(
self
.
classifier
=
nn
.
Sequential
(
nn
.
Dropout
(
dropout_rate
),
nn
.
Dropout
(
dropout_rate
),
nn
.
Linear
(
self
.
widths
[
7
],
num_labels
),
nn
.
Linear
(
cast
(
int
,
self
.
widths
[
7
]
)
,
num_labels
),
)
)
reset_parameters
(
self
,
bn_momentum
=
bn_momentum
,
bn_eps
=
bn_eps
)
reset_parameters
(
self
,
bn_momentum
=
bn_momentum
,
bn_eps
=
bn_eps
)
...
...
nni/retiarii/hub/pytorch/nasbench101.py
View file @
18962129
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
math
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
nni.retiarii
import
model_wrapper
from
nni.retiarii
import
model_wrapper
from
nni.retiarii.nn.pytorch
import
NasBench101Cell
from
nni.retiarii.nn.pytorch
import
NasBench101Cell
...
@@ -11,7 +12,7 @@ from nni.retiarii.nn.pytorch import NasBench101Cell
...
@@ -11,7 +12,7 @@ from nni.retiarii.nn.pytorch import NasBench101Cell
__all__
=
[
'NasBench101'
]
__all__
=
[
'NasBench101'
]
def
truncated_normal_
(
tensor
,
mean
=
0
,
std
=
1
):
def
truncated_normal_
(
tensor
:
torch
.
Tensor
,
mean
:
float
=
0
,
std
:
float
=
1
):
# https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
# https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
size
=
tensor
.
shape
size
=
tensor
.
shape
tmp
=
tensor
.
new_empty
(
size
+
(
4
,)).
normal_
()
tmp
=
tensor
.
new_empty
(
size
+
(
4
,)).
normal_
()
...
@@ -117,9 +118,3 @@ class NasBench101(nn.Module):
...
@@ -117,9 +118,3 @@ class NasBench101(nn.Module):
out
=
self
.
gap
(
out
).
view
(
bs
,
-
1
)
out
=
self
.
gap
(
out
).
view
(
bs
,
-
1
)
out
=
self
.
classifier
(
out
)
out
=
self
.
classifier
(
out
)
return
out
return
out
def
reset_parameters
(
self
):
for
module
in
self
.
modules
():
if
isinstance
(
module
,
nn
.
BatchNorm2d
):
module
.
eps
=
self
.
config
.
bn_eps
module
.
momentum
=
self
.
config
.
bn_momentum
nni/retiarii/hub/pytorch/nasbench201.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
typing
import
Callable
,
Dict
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -176,8 +178,10 @@ class NasBench201(nn.Module):
...
@@ -176,8 +178,10 @@ class NasBench201(nn.Module):
if
reduction
:
if
reduction
:
cell
=
ResNetBasicblock
(
C_prev
,
C_curr
,
2
)
cell
=
ResNetBasicblock
(
C_prev
,
C_curr
,
2
)
else
:
else
:
cell
=
NasBench201Cell
({
prim
:
lambda
C_in
,
C_out
:
OPS_WITH_STRIDE
[
prim
](
C_in
,
C_out
,
1
)
for
prim
in
PRIMITIVES
},
ops
:
Dict
[
str
,
Callable
[[
int
,
int
],
nn
.
Module
]]
=
{
C_prev
,
C_curr
,
label
=
'cell'
)
prim
:
lambda
C_in
,
C_out
:
OPS_WITH_STRIDE
[
prim
](
C_in
,
C_out
,
1
)
for
prim
in
PRIMITIVES
}
cell
=
NasBench201Cell
(
ops
,
C_prev
,
C_curr
,
label
=
'cell'
)
self
.
cells
.
append
(
cell
)
self
.
cells
.
append
(
cell
)
C_prev
=
C_curr
C_prev
=
C_curr
...
...
nni/retiarii/hub/pytorch/nasnet.py
View file @
18962129
...
@@ -8,7 +8,7 @@ It's called ``nasnet.py`` simply because NASNet is the first to propose such str
...
@@ -8,7 +8,7 @@ It's called ``nasnet.py`` simply because NASNet is the first to propose such str
"""
"""
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Tuple
,
List
,
Union
,
Iterable
,
Dict
,
Callable
from
typing
import
Tuple
,
List
,
Union
,
Iterable
,
Dict
,
Callable
,
Optional
,
cast
try
:
try
:
from
typing
import
Literal
from
typing
import
Literal
...
@@ -250,14 +250,14 @@ class CellPreprocessor(nn.Module):
...
@@ -250,14 +250,14 @@ class CellPreprocessor(nn.Module):
See :class:`CellBuilder` on how to calculate those channel numbers.
See :class:`CellBuilder` on how to calculate those channel numbers.
"""
"""
def
__init__
(
self
,
C_pprev
:
int
,
C_prev
:
int
,
C
:
int
,
last_cell_reduce
:
bool
)
->
None
:
def
__init__
(
self
,
C_pprev
:
nn
.
MaybeChoice
[
int
]
,
C_prev
:
nn
.
MaybeChoice
[
int
]
,
C
:
nn
.
MaybeChoice
[
int
]
,
last_cell_reduce
:
bool
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
last_cell_reduce
:
if
last_cell_reduce
:
self
.
pre0
=
FactorizedReduce
(
C_pprev
,
C
)
self
.
pre0
=
FactorizedReduce
(
cast
(
int
,
C_pprev
),
cast
(
int
,
C
)
)
else
:
else
:
self
.
pre0
=
ReLUConvBN
(
C_pprev
,
C
,
1
,
1
,
0
)
self
.
pre0
=
ReLUConvBN
(
cast
(
int
,
C_pprev
),
cast
(
int
,
C
)
,
1
,
1
,
0
)
self
.
pre1
=
ReLUConvBN
(
C_prev
,
C
,
1
,
1
,
0
)
self
.
pre1
=
ReLUConvBN
(
cast
(
int
,
C_prev
),
cast
(
int
,
C
)
,
1
,
1
,
0
)
def
forward
(
self
,
cells
):
def
forward
(
self
,
cells
):
assert
len
(
cells
)
==
2
assert
len
(
cells
)
==
2
...
@@ -283,15 +283,19 @@ class CellBuilder:
...
@@ -283,15 +283,19 @@ class CellBuilder:
Note that the builder is ephemeral, it can only be called once for every index.
Note that the builder is ephemeral, it can only be called once for every index.
"""
"""
def
__init__
(
self
,
op_candidates
:
List
[
str
],
C_prev_in
:
int
,
C_in
:
int
,
C
:
int
,
def
__init__
(
self
,
op_candidates
:
List
[
str
],
num_nodes
:
int
,
merge_op
:
Literal
[
'all'
,
'loose_end'
],
C_prev_in
:
nn
.
MaybeChoice
[
int
],
C_in
:
nn
.
MaybeChoice
[
int
],
C
:
nn
.
MaybeChoice
[
int
],
num_nodes
:
int
,
merge_op
:
Literal
[
'all'
,
'loose_end'
],
first_cell_reduce
:
bool
,
last_cell_reduce
:
bool
):
first_cell_reduce
:
bool
,
last_cell_reduce
:
bool
):
self
.
C_prev_in
=
C_prev_in
# This is the out channels of the cell before last cell.
self
.
C_prev_in
=
C_prev_in
# This is the out channels of the cell before last cell.
self
.
C_in
=
C_in
# This is the out channesl of last cell.
self
.
C_in
=
C_in
# This is the out channesl of last cell.
self
.
C
=
C
# This is NOT C_out of this stage, instead, C_out = C * len(cell.output_node_indices)
self
.
C
=
C
# This is NOT C_out of this stage, instead, C_out = C * len(cell.output_node_indices)
self
.
op_candidates
=
op_candidates
self
.
op_candidates
=
op_candidates
self
.
num_nodes
=
num_nodes
self
.
num_nodes
=
num_nodes
self
.
merge_op
=
merge_op
self
.
merge_op
:
Literal
[
'all'
,
'loose_end'
]
=
merge_op
self
.
first_cell_reduce
=
first_cell_reduce
self
.
first_cell_reduce
=
first_cell_reduce
self
.
last_cell_reduce
=
last_cell_reduce
self
.
last_cell_reduce
=
last_cell_reduce
self
.
_expect_idx
=
0
self
.
_expect_idx
=
0
...
@@ -312,7 +316,7 @@ class CellBuilder:
...
@@ -312,7 +316,7 @@ class CellBuilder:
# self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
# self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
preprocessor
=
CellPreprocessor
(
self
.
C_prev_in
,
self
.
C_in
,
self
.
C
,
self
.
last_cell_reduce
)
preprocessor
=
CellPreprocessor
(
self
.
C_prev_in
,
self
.
C_in
,
self
.
C
,
self
.
last_cell_reduce
)
ops_factory
:
Dict
[
str
,
Callable
[[
int
,
int
,
int
],
nn
.
Module
]]
=
{
ops_factory
:
Dict
[
str
,
Callable
[[
int
,
int
,
Optional
[
int
]
]
,
nn
.
Module
]]
=
{
op
:
# make final chosen ops named with their aliases
op
:
# make final chosen ops named with their aliases
lambda
node_index
,
op_index
,
input_index
:
lambda
node_index
,
op_index
,
input_index
:
OPS
[
op
](
self
.
C
,
2
if
is_reduction_cell
and
(
OPS
[
op
](
self
.
C
,
2
if
is_reduction_cell
and
(
...
@@ -353,7 +357,7 @@ _INIT_PARAMETER_DOCS = """
...
@@ -353,7 +357,7 @@ _INIT_PARAMETER_DOCS = """
class
NDS
(
nn
.
Module
):
class
NDS
(
nn
.
Module
):
"""
__doc__
=
"""
The unified version of NASNet search space.
The unified version of NASNet search space.
We follow the implementation in
We follow the implementation in
...
@@ -378,8 +382,8 @@ class NDS(nn.Module):
...
@@ -378,8 +382,8 @@ class NDS(nn.Module):
op_candidates
:
List
[
str
],
op_candidates
:
List
[
str
],
merge_op
:
Literal
[
'all'
,
'loose_end'
]
=
'all'
,
merge_op
:
Literal
[
'all'
,
'loose_end'
]
=
'all'
,
num_nodes_per_cell
:
int
=
4
,
num_nodes_per_cell
:
int
=
4
,
width
:
Union
[
Tuple
[
int
],
int
]
=
16
,
width
:
Union
[
Tuple
[
int
,
...
],
int
]
=
16
,
num_cells
:
Union
[
Tuple
[
int
],
int
]
=
20
,
num_cells
:
Union
[
Tuple
[
int
,
...
],
int
]
=
20
,
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'imagenet'
,
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'imagenet'
,
auxiliary_loss
:
bool
=
False
):
auxiliary_loss
:
bool
=
False
):
super
().
__init__
()
super
().
__init__
()
...
@@ -394,30 +398,31 @@ class NDS(nn.Module):
...
@@ -394,30 +398,31 @@ class NDS(nn.Module):
else
:
else
:
C
=
width
C
=
width
self
.
num_cells
:
nn
.
MaybeChoice
[
int
]
=
cast
(
int
,
num_cells
)
if
isinstance
(
num_cells
,
Iterable
):
if
isinstance
(
num_cells
,
Iterable
):
num_cells
=
nn
.
ValueChoice
(
list
(
num_cells
),
label
=
'depth'
)
self
.
num_cells
=
nn
.
ValueChoice
(
list
(
num_cells
),
label
=
'depth'
)
num_cells_per_stage
=
[
i
*
num_cells
//
3
-
(
i
-
1
)
*
num_cells
//
3
for
i
in
range
(
3
)]
num_cells_per_stage
=
[
i
*
self
.
num_cells
//
3
-
(
i
-
1
)
*
self
.
num_cells
//
3
for
i
in
range
(
3
)]
# auxiliary head is different for network targetted at different datasets
# auxiliary head is different for network targetted at different datasets
if
dataset
==
'imagenet'
:
if
dataset
==
'imagenet'
:
self
.
stem0
=
nn
.
Sequential
(
self
.
stem0
=
nn
.
Sequential
(
nn
.
Conv2d
(
3
,
C
//
2
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
bias
=
False
),
nn
.
Conv2d
(
3
,
cast
(
int
,
C
//
2
)
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
C
//
2
),
nn
.
BatchNorm2d
(
cast
(
int
,
C
//
2
)
)
,
nn
.
ReLU
(
inplace
=
True
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Conv2d
(
C
//
2
,
C
,
3
,
stride
=
2
,
padding
=
1
,
bias
=
False
),
nn
.
Conv2d
(
cast
(
int
,
C
//
2
),
cast
(
int
,
C
)
,
3
,
stride
=
2
,
padding
=
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
C
),
nn
.
BatchNorm2d
(
C
),
)
)
self
.
stem1
=
nn
.
Sequential
(
self
.
stem1
=
nn
.
Sequential
(
nn
.
ReLU
(
inplace
=
True
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Conv2d
(
C
,
C
,
3
,
stride
=
2
,
padding
=
1
,
bias
=
False
),
nn
.
Conv2d
(
cast
(
int
,
C
),
cast
(
int
,
C
)
,
3
,
stride
=
2
,
padding
=
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
C
),
nn
.
BatchNorm2d
(
C
),
)
)
C_pprev
=
C_prev
=
C_curr
=
C
C_pprev
=
C_prev
=
C_curr
=
C
last_cell_reduce
=
True
last_cell_reduce
=
True
elif
dataset
==
'cifar'
:
elif
dataset
==
'cifar'
:
self
.
stem
=
nn
.
Sequential
(
self
.
stem
=
nn
.
Sequential
(
nn
.
Conv2d
(
3
,
3
*
C
,
3
,
padding
=
1
,
bias
=
False
),
nn
.
Conv2d
(
3
,
cast
(
int
,
3
*
C
)
,
3
,
padding
=
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
3
*
C
)
nn
.
BatchNorm2d
(
cast
(
int
,
3
*
C
)
)
)
)
C_pprev
=
C_prev
=
3
*
C
C_pprev
=
C_prev
=
3
*
C
C_curr
=
C
C_curr
=
C
...
@@ -439,7 +444,7 @@ class NDS(nn.Module):
...
@@ -439,7 +444,7 @@ class NDS(nn.Module):
# C_pprev is output channel number of last second cell among all the cells already built.
# C_pprev is output channel number of last second cell among all the cells already built.
if
len
(
stage
)
>
1
:
if
len
(
stage
)
>
1
:
# Contains more than one cell
# Contains more than one cell
C_pprev
=
len
(
stage
[
-
2
].
output_node_indices
)
*
C_curr
C_pprev
=
len
(
cast
(
nn
.
Cell
,
stage
[
-
2
]
)
.
output_node_indices
)
*
C_curr
else
:
else
:
# Look up in the out channels of last stage.
# Look up in the out channels of last stage.
C_pprev
=
C_prev
C_pprev
=
C_prev
...
@@ -447,7 +452,7 @@ class NDS(nn.Module):
...
@@ -447,7 +452,7 @@ class NDS(nn.Module):
# This was originally,
# This was originally,
# C_prev = num_nodes_per_cell * C_curr.
# C_prev = num_nodes_per_cell * C_curr.
# but due to loose end, it becomes,
# but due to loose end, it becomes,
C_prev
=
len
(
stage
[
-
1
].
output_node_indices
)
*
C_curr
C_prev
=
len
(
cast
(
nn
.
Cell
,
stage
[
-
1
]
)
.
output_node_indices
)
*
C_curr
# Useful in aligning the pprev and prev cell.
# Useful in aligning the pprev and prev cell.
last_cell_reduce
=
cell_builder
.
last_cell_reduce
last_cell_reduce
=
cell_builder
.
last_cell_reduce
...
@@ -457,11 +462,11 @@ class NDS(nn.Module):
...
@@ -457,11 +462,11 @@ class NDS(nn.Module):
if
auxiliary_loss
:
if
auxiliary_loss
:
assert
isinstance
(
self
.
stages
[
2
],
nn
.
Sequential
),
'Auxiliary loss can only be enabled in retrain mode.'
assert
isinstance
(
self
.
stages
[
2
],
nn
.
Sequential
),
'Auxiliary loss can only be enabled in retrain mode.'
self
.
stages
[
2
]
=
SequentialBreakdown
(
self
.
stages
[
2
])
self
.
stages
[
2
]
=
SequentialBreakdown
(
cast
(
nn
.
Sequential
,
self
.
stages
[
2
])
)
self
.
auxiliary_head
=
AuxiliaryHead
(
C_to_auxiliary
,
self
.
num_labels
,
dataset
=
self
.
dataset
)
self
.
auxiliary_head
=
AuxiliaryHead
(
C_to_auxiliary
,
self
.
num_labels
,
dataset
=
self
.
dataset
)
# type: ignore
self
.
global_pooling
=
nn
.
AdaptiveAvgPool2d
((
1
,
1
))
self
.
global_pooling
=
nn
.
AdaptiveAvgPool2d
((
1
,
1
))
self
.
classifier
=
nn
.
Linear
(
C_prev
,
self
.
num_labels
)
self
.
classifier
=
nn
.
Linear
(
cast
(
int
,
C_prev
)
,
self
.
num_labels
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
if
self
.
dataset
==
'imagenet'
:
if
self
.
dataset
==
'imagenet'
:
...
@@ -483,7 +488,7 @@ class NDS(nn.Module):
...
@@ -483,7 +488,7 @@ class NDS(nn.Module):
out
=
self
.
global_pooling
(
s1
)
out
=
self
.
global_pooling
(
s1
)
logits
=
self
.
classifier
(
out
.
view
(
out
.
size
(
0
),
-
1
))
logits
=
self
.
classifier
(
out
.
view
(
out
.
size
(
0
),
-
1
))
if
self
.
training
and
self
.
auxiliary_loss
:
if
self
.
training
and
self
.
auxiliary_loss
:
return
logits
,
logits_aux
return
logits
,
logits_aux
# type: ignore
else
:
else
:
return
logits
return
logits
...
@@ -524,8 +529,8 @@ class NASNet(NDS):
...
@@ -524,8 +529,8 @@ class NASNet(NDS):
]
]
def
__init__
(
self
,
def
__init__
(
self
,
width
:
Union
[
Tuple
[
int
],
int
]
=
(
16
,
24
,
32
),
width
:
Union
[
Tuple
[
int
,
...
],
int
]
=
(
16
,
24
,
32
),
num_cells
:
Union
[
Tuple
[
int
],
int
]
=
(
4
,
8
,
12
,
16
,
20
),
num_cells
:
Union
[
Tuple
[
int
,
...
],
int
]
=
(
4
,
8
,
12
,
16
,
20
),
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'cifar'
,
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'cifar'
,
auxiliary_loss
:
bool
=
False
):
auxiliary_loss
:
bool
=
False
):
super
().
__init__
(
self
.
NASNET_OPS
,
super
().
__init__
(
self
.
NASNET_OPS
,
...
@@ -555,8 +560,8 @@ class ENAS(NDS):
...
@@ -555,8 +560,8 @@ class ENAS(NDS):
]
]
def
__init__
(
self
,
def
__init__
(
self
,
width
:
Union
[
Tuple
[
int
],
int
]
=
(
16
,
24
,
32
),
width
:
Union
[
Tuple
[
int
,
...
],
int
]
=
(
16
,
24
,
32
),
num_cells
:
Union
[
Tuple
[
int
],
int
]
=
(
4
,
8
,
12
,
16
,
20
),
num_cells
:
Union
[
Tuple
[
int
,
...
],
int
]
=
(
4
,
8
,
12
,
16
,
20
),
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'cifar'
,
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'cifar'
,
auxiliary_loss
:
bool
=
False
):
auxiliary_loss
:
bool
=
False
):
super
().
__init__
(
self
.
ENAS_OPS
,
super
().
__init__
(
self
.
ENAS_OPS
,
...
@@ -590,8 +595,8 @@ class AmoebaNet(NDS):
...
@@ -590,8 +595,8 @@ class AmoebaNet(NDS):
]
]
def
__init__
(
self
,
def
__init__
(
self
,
width
:
Union
[
Tuple
[
int
],
int
]
=
(
16
,
24
,
32
),
width
:
Union
[
Tuple
[
int
,
...
],
int
]
=
(
16
,
24
,
32
),
num_cells
:
Union
[
Tuple
[
int
],
int
]
=
(
4
,
8
,
12
,
16
,
20
),
num_cells
:
Union
[
Tuple
[
int
,
...
],
int
]
=
(
4
,
8
,
12
,
16
,
20
),
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'cifar'
,
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'cifar'
,
auxiliary_loss
:
bool
=
False
):
auxiliary_loss
:
bool
=
False
):
...
@@ -626,8 +631,8 @@ class PNAS(NDS):
...
@@ -626,8 +631,8 @@ class PNAS(NDS):
]
]
def
__init__
(
self
,
def
__init__
(
self
,
width
:
Union
[
Tuple
[
int
],
int
]
=
(
16
,
24
,
32
),
width
:
Union
[
Tuple
[
int
,
...
],
int
]
=
(
16
,
24
,
32
),
num_cells
:
Union
[
Tuple
[
int
],
int
]
=
(
4
,
8
,
12
,
16
,
20
),
num_cells
:
Union
[
Tuple
[
int
,
...
],
int
]
=
(
4
,
8
,
12
,
16
,
20
),
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'cifar'
,
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'cifar'
,
auxiliary_loss
:
bool
=
False
):
auxiliary_loss
:
bool
=
False
):
super
().
__init__
(
self
.
PNAS_OPS
,
super
().
__init__
(
self
.
PNAS_OPS
,
...
@@ -660,8 +665,8 @@ class DARTS(NDS):
...
@@ -660,8 +665,8 @@ class DARTS(NDS):
]
]
def
__init__
(
self
,
def
__init__
(
self
,
width
:
Union
[
Tuple
[
int
],
int
]
=
(
16
,
24
,
32
),
width
:
Union
[
Tuple
[
int
,
...
],
int
]
=
(
16
,
24
,
32
),
num_cells
:
Union
[
Tuple
[
int
],
int
]
=
(
4
,
8
,
12
,
16
,
20
),
num_cells
:
Union
[
Tuple
[
int
,
...
],
int
]
=
(
4
,
8
,
12
,
16
,
20
),
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'cifar'
,
dataset
:
Literal
[
'cifar'
,
'imagenet'
]
=
'cifar'
,
auxiliary_loss
:
bool
=
False
):
auxiliary_loss
:
bool
=
False
):
super
().
__init__
(
self
.
DARTS_OPS
,
super
().
__init__
(
self
.
DARTS_OPS
,
...
...
nni/retiarii/hub/pytorch/proxylessnas.py
View file @
18962129
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
math
import
math
from
typing
import
Optional
,
Callable
,
List
,
Tuple
from
typing
import
Optional
,
Callable
,
List
,
Tuple
,
cast
import
torch
import
torch
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
...
@@ -31,12 +31,12 @@ class ConvBNReLU(nn.Sequential):
...
@@ -31,12 +31,12 @@ class ConvBNReLU(nn.Sequential):
def
__init__
(
def
__init__
(
self
,
self
,
in_channels
:
int
,
in_channels
:
nn
.
MaybeChoice
[
int
]
,
out_channels
:
int
,
out_channels
:
nn
.
MaybeChoice
[
int
]
,
kernel_size
:
int
=
3
,
kernel_size
:
nn
.
MaybeChoice
[
int
]
=
3
,
stride
:
int
=
1
,
stride
:
int
=
1
,
groups
:
int
=
1
,
groups
:
nn
.
MaybeChoice
[
int
]
=
1
,
norm_layer
:
Optional
[
Callable
[
...
,
nn
.
Module
]]
=
None
,
norm_layer
:
Optional
[
Callable
[
[
int
]
,
nn
.
Module
]]
=
None
,
activation_layer
:
Optional
[
Callable
[...,
nn
.
Module
]]
=
None
,
activation_layer
:
Optional
[
Callable
[...,
nn
.
Module
]]
=
None
,
dilation
:
int
=
1
,
dilation
:
int
=
1
,
)
->
None
:
)
->
None
:
...
@@ -46,9 +46,17 @@ class ConvBNReLU(nn.Sequential):
...
@@ -46,9 +46,17 @@ class ConvBNReLU(nn.Sequential):
if
activation_layer
is
None
:
if
activation_layer
is
None
:
activation_layer
=
nn
.
ReLU6
activation_layer
=
nn
.
ReLU6
super
().
__init__
(
super
().
__init__
(
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
dilation
=
dilation
,
groups
=
groups
,
nn
.
Conv2d
(
bias
=
False
),
cast
(
int
,
in_channels
),
norm_layer
(
out_channels
),
cast
(
int
,
out_channels
),
cast
(
int
,
kernel_size
),
stride
,
cast
(
int
,
padding
),
dilation
=
dilation
,
groups
=
cast
(
int
,
groups
),
bias
=
False
),
norm_layer
(
cast
(
int
,
out_channels
)),
activation_layer
(
inplace
=
True
)
activation_layer
(
inplace
=
True
)
)
)
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
...
@@ -62,11 +70,11 @@ class SeparableConv(nn.Sequential):
...
@@ -62,11 +70,11 @@ class SeparableConv(nn.Sequential):
def
__init__
(
def
__init__
(
self
,
self
,
in_channels
:
int
,
in_channels
:
nn
.
MaybeChoice
[
int
]
,
out_channels
:
int
,
out_channels
:
nn
.
MaybeChoice
[
int
]
,
kernel_size
:
int
=
3
,
kernel_size
:
nn
.
MaybeChoice
[
int
]
=
3
,
stride
:
int
=
1
,
stride
:
int
=
1
,
norm_layer
:
Optional
[
Callable
[
...
,
nn
.
Module
]]
=
None
,
norm_layer
:
Optional
[
Callable
[
[
int
]
,
nn
.
Module
]]
=
None
,
activation_layer
:
Optional
[
Callable
[...,
nn
.
Module
]]
=
None
,
activation_layer
:
Optional
[
Callable
[...,
nn
.
Module
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
(
super
().
__init__
(
...
@@ -101,13 +109,13 @@ class InvertedResidual(nn.Sequential):
...
@@ -101,13 +109,13 @@ class InvertedResidual(nn.Sequential):
def
__init__
(
def
__init__
(
self
,
self
,
in_channels
:
int
,
in_channels
:
nn
.
MaybeChoice
[
int
]
,
out_channels
:
int
,
out_channels
:
nn
.
MaybeChoice
[
int
]
,
expand_ratio
:
int
,
expand_ratio
:
nn
.
MaybeChoice
[
float
]
,
kernel_size
:
int
=
3
,
kernel_size
:
nn
.
MaybeChoice
[
int
]
=
3
,
stride
:
int
=
1
,
stride
:
int
=
1
,
squeeze_and_excite
:
Optional
[
Callable
[[
int
],
nn
.
Module
]]
=
None
,
squeeze_and_excite
:
Optional
[
Callable
[[
nn
.
MaybeChoice
[
int
]
]
,
nn
.
Module
]]
=
None
,
norm_layer
:
Optional
[
Callable
[
...
,
nn
.
Module
]]
=
None
,
norm_layer
:
Optional
[
Callable
[
[
int
]
,
nn
.
Module
]]
=
None
,
activation_layer
:
Optional
[
Callable
[...,
nn
.
Module
]]
=
None
,
activation_layer
:
Optional
[
Callable
[...,
nn
.
Module
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -115,7 +123,7 @@ class InvertedResidual(nn.Sequential):
...
@@ -115,7 +123,7 @@ class InvertedResidual(nn.Sequential):
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
assert
stride
in
[
1
,
2
]
assert
stride
in
[
1
,
2
]
hidden_ch
=
nn
.
ValueChoice
.
to_int
(
round
(
in_channels
*
expand_ratio
))
hidden_ch
=
nn
.
ValueChoice
.
to_int
(
round
(
cast
(
int
,
in_channels
*
expand_ratio
))
)
# FIXME: check whether this equal works
# FIXME: check whether this equal works
# Residual connection is added here stride = 1 and input channels and output channels are the same.
# Residual connection is added here stride = 1 and input channels and output channels are the same.
...
@@ -215,7 +223,7 @@ class ProxylessNAS(nn.Module):
...
@@ -215,7 +223,7 @@ class ProxylessNAS(nn.Module):
self
.
first_conv
=
ConvBNReLU
(
3
,
widths
[
0
],
stride
=
2
,
norm_layer
=
nn
.
BatchNorm2d
)
self
.
first_conv
=
ConvBNReLU
(
3
,
widths
[
0
],
stride
=
2
,
norm_layer
=
nn
.
BatchNorm2d
)
blocks
=
[
blocks
:
List
[
nn
.
Module
]
=
[
# first stage is fixed
# first stage is fixed
SeparableConv
(
widths
[
0
],
widths
[
1
],
kernel_size
=
3
,
stride
=
1
)
SeparableConv
(
widths
[
0
],
widths
[
1
],
kernel_size
=
3
,
stride
=
1
)
]
]
...
...
nni/retiarii/hub/pytorch/shufflenet.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
typing
import
cast
import
torch
import
torch
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
model_wrapper
from
nni.retiarii
import
model_wrapper
...
@@ -14,7 +16,7 @@ class ShuffleNetBlock(nn.Module):
...
@@ -14,7 +16,7 @@ class ShuffleNetBlock(nn.Module):
When stride = 1, the block expects an input with ``2 * input channels``. Otherwise input channels.
When stride = 1, the block expects an input with ``2 * input channels``. Otherwise input channels.
"""
"""
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
,
mid_channels
:
int
,
*
,
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
,
mid_channels
:
nn
.
MaybeChoice
[
int
]
,
*
,
kernel_size
:
int
,
stride
:
int
,
sequence
:
str
=
"pdp"
,
affine
:
bool
=
True
):
kernel_size
:
int
,
stride
:
int
,
sequence
:
str
=
"pdp"
,
affine
:
bool
=
True
):
super
().
__init__
()
super
().
__init__
()
assert
stride
in
[
1
,
2
]
assert
stride
in
[
1
,
2
]
...
@@ -57,14 +59,15 @@ class ShuffleNetBlock(nn.Module):
...
@@ -57,14 +59,15 @@ class ShuffleNetBlock(nn.Module):
def
_decode_point_depth_conv
(
self
,
sequence
):
def
_decode_point_depth_conv
(
self
,
sequence
):
result
=
[]
result
=
[]
first_depth
=
first_point
=
True
first_depth
=
first_point
=
True
pc
=
c
=
self
.
channels
pc
:
int
=
self
.
channels
c
:
int
=
self
.
channels
for
i
,
token
in
enumerate
(
sequence
):
for
i
,
token
in
enumerate
(
sequence
):
# compute output channels of this conv
# compute output channels of this conv
if
i
+
1
==
len
(
sequence
):
if
i
+
1
==
len
(
sequence
):
assert
token
==
"p"
,
"Last conv must be point-wise conv."
assert
token
==
"p"
,
"Last conv must be point-wise conv."
c
=
self
.
oup_main
c
=
self
.
oup_main
elif
token
==
"p"
and
first_point
:
elif
token
==
"p"
and
first_point
:
c
=
self
.
mid_channels
c
=
cast
(
int
,
self
.
mid_channels
)
if
token
==
"d"
:
if
token
==
"d"
:
# depth-wise conv
# depth-wise conv
if
isinstance
(
pc
,
int
)
and
isinstance
(
c
,
int
):
if
isinstance
(
pc
,
int
)
and
isinstance
(
c
,
int
):
...
@@ -101,7 +104,7 @@ class ShuffleXceptionBlock(ShuffleNetBlock):
...
@@ -101,7 +104,7 @@ class ShuffleXceptionBlock(ShuffleNetBlock):
`Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
`Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
"""
"""
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
,
mid_channels
:
int
,
*
,
stride
:
int
,
affine
:
bool
=
True
):
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
,
mid_channels
:
nn
.
MaybeChoice
[
int
]
,
*
,
stride
:
int
,
affine
:
bool
=
True
):
super
().
__init__
(
in_channels
,
out_channels
,
mid_channels
,
super
().
__init__
(
in_channels
,
out_channels
,
mid_channels
,
kernel_size
=
3
,
stride
=
stride
,
sequence
=
"dpdpdp"
,
affine
=
affine
)
kernel_size
=
3
,
stride
=
stride
,
sequence
=
"dpdpdp"
,
affine
=
affine
)
...
@@ -154,7 +157,7 @@ class ShuffleNetSpace(nn.Module):
...
@@ -154,7 +157,7 @@ class ShuffleNetSpace(nn.Module):
nn
.
ReLU
(
inplace
=
True
),
nn
.
ReLU
(
inplace
=
True
),
)
)
self
.
features
=
[]
feature
_block
s
=
[]
global_block_idx
=
0
global_block_idx
=
0
for
stage_idx
,
num_repeat
in
enumerate
(
self
.
stage_repeats
):
for
stage_idx
,
num_repeat
in
enumerate
(
self
.
stage_repeats
):
...
@@ -175,15 +178,17 @@ class ShuffleNetSpace(nn.Module):
...
@@ -175,15 +178,17 @@ class ShuffleNetSpace(nn.Module):
else
:
else
:
mid_channels
=
int
(
base_mid_channels
)
mid_channels
=
int
(
base_mid_channels
)
mid_channels
=
cast
(
nn
.
MaybeChoice
[
int
],
mid_channels
)
choice_block
=
nn
.
LayerChoice
([
choice_block
=
nn
.
LayerChoice
([
ShuffleNetBlock
(
in_channels
,
out_channels
,
mid_channels
=
mid_channels
,
kernel_size
=
3
,
stride
=
stride
,
affine
=
affine
),
ShuffleNetBlock
(
in_channels
,
out_channels
,
mid_channels
=
mid_channels
,
kernel_size
=
3
,
stride
=
stride
,
affine
=
affine
),
ShuffleNetBlock
(
in_channels
,
out_channels
,
mid_channels
=
mid_channels
,
kernel_size
=
5
,
stride
=
stride
,
affine
=
affine
),
ShuffleNetBlock
(
in_channels
,
out_channels
,
mid_channels
=
mid_channels
,
kernel_size
=
5
,
stride
=
stride
,
affine
=
affine
),
ShuffleNetBlock
(
in_channels
,
out_channels
,
mid_channels
=
mid_channels
,
kernel_size
=
7
,
stride
=
stride
,
affine
=
affine
),
ShuffleNetBlock
(
in_channels
,
out_channels
,
mid_channels
=
mid_channels
,
kernel_size
=
7
,
stride
=
stride
,
affine
=
affine
),
ShuffleXceptionBlock
(
in_channels
,
out_channels
,
mid_channels
=
mid_channels
,
stride
=
stride
,
affine
=
affine
)
ShuffleXceptionBlock
(
in_channels
,
out_channels
,
mid_channels
=
mid_channels
,
stride
=
stride
,
affine
=
affine
)
],
label
=
f
'layer_
{
global_block_idx
}
'
)
],
label
=
f
'layer_
{
global_block_idx
}
'
)
self
.
features
.
append
(
choice_block
)
feature
_block
s
.
append
(
choice_block
)
self
.
features
=
nn
.
Sequential
(
*
self
.
features
)
self
.
features
=
nn
.
Sequential
(
*
feature
_block
s
)
# final layers
# final layers
last_conv_channels
=
self
.
stage_out_channels
[
-
1
]
last_conv_channels
=
self
.
stage_out_channels
[
-
1
]
...
@@ -226,13 +231,15 @@ class ShuffleNetSpace(nn.Module):
...
@@ -226,13 +231,15 @@ class ShuffleNetSpace(nn.Module):
torch
.
nn
.
init
.
constant_
(
m
.
weight
,
1
)
torch
.
nn
.
init
.
constant_
(
m
.
weight
,
1
)
if
m
.
bias
is
not
None
:
if
m
.
bias
is
not
None
:
torch
.
nn
.
init
.
constant_
(
m
.
bias
,
0.0001
)
torch
.
nn
.
init
.
constant_
(
m
.
bias
,
0.0001
)
torch
.
nn
.
init
.
constant_
(
m
.
running_mean
,
0
)
if
m
.
running_mean
is
not
None
:
torch
.
nn
.
init
.
constant_
(
m
.
running_mean
,
0
)
elif
isinstance
(
m
,
nn
.
BatchNorm1d
):
elif
isinstance
(
m
,
nn
.
BatchNorm1d
):
if
m
.
weight
is
not
None
:
if
m
.
weight
is
not
None
:
torch
.
nn
.
init
.
constant_
(
m
.
weight
,
1
)
torch
.
nn
.
init
.
constant_
(
m
.
weight
,
1
)
if
m
.
bias
is
not
None
:
if
m
.
bias
is
not
None
:
torch
.
nn
.
init
.
constant_
(
m
.
bias
,
0.0001
)
torch
.
nn
.
init
.
constant_
(
m
.
bias
,
0.0001
)
torch
.
nn
.
init
.
constant_
(
m
.
running_mean
,
0
)
if
m
.
running_mean
is
not
None
:
torch
.
nn
.
init
.
constant_
(
m
.
running_mean
,
0
)
elif
isinstance
(
m
,
nn
.
Linear
):
elif
isinstance
(
m
,
nn
.
Linear
):
torch
.
nn
.
init
.
normal_
(
m
.
weight
,
0
,
0.01
)
torch
.
nn
.
init
.
normal_
(
m
.
weight
,
0
,
0.01
)
if
m
.
bias
is
not
None
:
if
m
.
bias
is
not
None
:
...
...
nni/retiarii/hub/pytorch/utils.py
0 → 100644
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Useful type hints
nni/retiarii/integration.py
View file @
18962129
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
logging
import
logging
import
os
import
os
from
typing
import
Any
,
Callable
from
typing
import
Any
,
Callable
,
Optional
import
nni
import
nni
from
nni.common.serializer
import
PayloadTooLarge
from
nni.common.serializer
import
PayloadTooLarge
...
@@ -53,11 +53,11 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -53,11 +53,11 @@ class RetiariiAdvisor(MsgDispatcherBase):
register_advisor
(
self
)
# register the current advisor as the "global only" advisor
register_advisor
(
self
)
# register the current advisor as the "global only" advisor
self
.
search_space
=
None
self
.
search_space
=
None
self
.
send_trial_callback
:
Callable
[[
dict
],
None
]
=
None
self
.
send_trial_callback
:
Optional
[
Callable
[[
dict
],
None
]
]
=
None
self
.
request_trial_jobs_callback
:
Callable
[[
int
],
None
]
=
None
self
.
request_trial_jobs_callback
:
Optional
[
Callable
[[
int
],
None
]
]
=
None
self
.
trial_end_callback
:
Callable
[[
int
,
bool
],
None
]
=
None
self
.
trial_end_callback
:
Optional
[
Callable
[[
int
,
bool
],
None
]
]
=
None
self
.
intermediate_metric_callback
:
Callable
[[
int
,
MetricData
],
None
]
=
None
self
.
intermediate_metric_callback
:
Optional
[
Callable
[[
int
,
MetricData
],
None
]
]
=
None
self
.
final_metric_callback
:
Callable
[[
int
,
MetricData
],
None
]
=
None
self
.
final_metric_callback
:
Optional
[
Callable
[[
int
,
MetricData
],
None
]
]
=
None
self
.
parameters_count
=
0
self
.
parameters_count
=
0
...
@@ -158,19 +158,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -158,19 +158,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
def
handle_trial_end
(
self
,
data
):
def
handle_trial_end
(
self
,
data
):
_logger
.
debug
(
'Trial end: %s'
,
data
)
_logger
.
debug
(
'Trial end: %s'
,
data
)
self
.
trial_end_callback
(
nni
.
load
(
data
[
'hyper_params'
])[
'parameter_id'
],
# pylint: disable=not-callable
if
self
.
trial_end_callback
is
not
None
:
data
[
'event'
]
==
'SUCCEEDED'
)
self
.
trial_end_callback
(
nni
.
load
(
data
[
'hyper_params'
])[
'parameter_id'
],
# pylint: disable=not-callable
data
[
'event'
]
==
'SUCCEEDED'
)
def
handle_report_metric_data
(
self
,
data
):
def
handle_report_metric_data
(
self
,
data
):
_logger
.
debug
(
'Metric reported: %s'
,
data
)
_logger
.
debug
(
'Metric reported: %s'
,
data
)
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
raise
ValueError
(
'Request parameter not supported'
)
raise
ValueError
(
'Request parameter not supported'
)
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
self
.
intermediate_metric_callback
(
data
[
'parameter_id'
],
# pylint: disable=not-callable
if
self
.
intermediate_metric_callback
is
not
None
:
self
.
_process_value
(
data
[
'value'
]))
self
.
intermediate_metric_callback
(
data
[
'parameter_id'
],
# pylint: disable=not-callable
self
.
_process_value
(
data
[
'value'
]))
elif
data
[
'type'
]
==
MetricType
.
FINAL
:
elif
data
[
'type'
]
==
MetricType
.
FINAL
:
self
.
final_metric_callback
(
data
[
'parameter_id'
],
# pylint: disable=not-callable
if
self
.
final_metric_callback
is
not
None
:
self
.
_process_value
(
data
[
'value'
]))
self
.
final_metric_callback
(
data
[
'parameter_id'
],
# pylint: disable=not-callable
self
.
_process_value
(
data
[
'value'
]))
@
staticmethod
@
staticmethod
def
_process_value
(
value
)
->
Any
:
# hopefully a float
def
_process_value
(
value
)
->
Any
:
# hopefully a float
...
...
nni/retiarii/mutator.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
typing
import
(
Any
,
Iterable
,
List
,
Optional
,
Tuple
)
import
warnings
from
typing
import
(
Any
,
Iterable
,
List
,
Optional
,
Tuple
,
cast
)
from
.graph
import
Model
,
Mutation
,
ModelStatus
from
.graph
import
Model
,
Mutation
,
ModelStatus
...
@@ -44,9 +45,11 @@ class Mutator:
...
@@ -44,9 +45,11 @@ class Mutator:
If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
"""
"""
def
__init__
(
self
,
sampler
:
Optional
[
Sampler
]
=
None
,
label
:
Optional
[
str
]
=
None
):
def
__init__
(
self
,
sampler
:
Optional
[
Sampler
]
=
None
,
label
:
str
=
cast
(
str
,
None
)
)
:
self
.
sampler
:
Optional
[
Sampler
]
=
sampler
self
.
sampler
:
Optional
[
Sampler
]
=
sampler
self
.
label
:
Optional
[
str
]
=
label
if
label
is
None
:
warnings
.
warn
(
'Each mutator should have an explicit label. Mutator without label is deprecated.'
,
DeprecationWarning
)
self
.
label
:
str
=
label
self
.
_cur_model
:
Optional
[
Model
]
=
None
self
.
_cur_model
:
Optional
[
Model
]
=
None
self
.
_cur_choice_idx
:
Optional
[
int
]
=
None
self
.
_cur_choice_idx
:
Optional
[
int
]
=
None
...
...
nni/retiarii/nn/pytorch/api.py
View file @
18962129
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment