OpenDAS / nni · Commits · Commit 18962129

Add license header and typehints for NAS (#4774)

Unverified commit 18962129, authored Apr 25, 2022 by Yuge Zhang, committed by GitHub on Apr 25, 2022. Parent: 8c2f717d. Changes: 96.

Showing 20 changed files with 354 additions and 270 deletions (+354 / -270):
nni/retiarii/evaluator/pytorch/lightning.py    +19   -15
nni/retiarii/execution/__init__.py              +3    -0
nni/retiarii/execution/base.py                  +1    -0
nni/retiarii/execution/benchmark.py             +9    -2
nni/retiarii/execution/cgo_engine.py            +0    -1
nni/retiarii/execution/python.py                +5    -1
nni/retiarii/execution/utils.py                 +3    -0
nni/retiarii/experiment/pytorch.py             +27   -16
nni/retiarii/fixed.py                           +3    -0
nni/retiarii/graph.py                          +29   -16
nni/retiarii/hub/pytorch/mobilenetv3.py         +4    -4
nni/retiarii/hub/pytorch/nasbench101.py         +2    -7
nni/retiarii/hub/pytorch/nasbench201.py         +6    -2
nni/retiarii/hub/pytorch/nasnet.py             +41   -36
nni/retiarii/hub/pytorch/proxylessnas.py       +29   -21
nni/retiarii/hub/pytorch/shufflenet.py         +16    -9
nni/retiarii/hub/pytorch/utils.py               +5    -0
nni/retiarii/integration.py                    +15   -12
nni/retiarii/mutator.py                         +6    -3
nni/retiarii/nn/pytorch/api.py                +131  -125
nni/retiarii/evaluator/pytorch/lightning.py  (view file @ 18962129)

@@ -4,10 +4,11 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Union, Optional, List, Callable
+from typing import Dict, Union, Optional, List, Callable, Type

 import pytorch_lightning as pl
 import torch.nn as nn
+import torch.nn.functional as nn_functional
 import torch.optim as optim
 import torchmetrics
 import torch.utils.data as torch_data

@@ -124,12 +125,12 @@ class Lightning(Evaluator):
         if other is None:
             return False
         if hasattr(self, "function") and hasattr(other, "function"):
-            eq_func = (self.function == other.function)
+            eq_func = getattr(self, "function") == getattr(other, "function")
         elif not (hasattr(self, "function") or hasattr(other, "function")):
             eq_func = True
         if hasattr(self, "arguments") and hasattr(other, "arguments"):
-            eq_args = (self.arguments == other.arguments)
+            eq_args = getattr(self, "arguments") == getattr(other, "arguments")
         elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
             eq_args = True

@@ -159,10 +160,13 @@ def _check_dataloader(dataloader):
 ### The following are some commonly used Lightning modules ###

 class _SupervisedLearningModule(LightningModule):
-    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
+    trainer: pl.Trainer
+
+    def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetrics.Metric]],
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  export_onnx: Union[Path, str, bool, None] = None):
         super().__init__()
         self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')

@@ -214,7 +218,7 @@ class _SupervisedLearningModule(LightningModule):
             self.log('test_' + name, metric(y_hat, y), prog_bar=True)

     def configure_optimizers(self):
-        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
+        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore

     def on_validation_epoch_end(self):
         nni.report_intermediate_result(self._get_validation_metrics())

@@ -233,15 +237,15 @@ class _SupervisedLearningModule(LightningModule):
 class _AccuracyWithLogits(torchmetrics.Accuracy):
     def update(self, pred, target):
-        return super().update(nn.functional.softmax(pred), target)
+        return super().update(nn_functional.softmax(pred), target)

 @nni.trace
 class _ClassificationModule(_SupervisedLearningModule):
-    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  export_onnx: bool = True):
         super().__init__(criterion, {'acc': _AccuracyWithLogits},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,

@@ -275,10 +279,10 @@ class Classification(Lightning):
     `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
     """

-    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  train_dataloader: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,

@@ -291,10 +295,10 @@ class Classification(Lightning):
 @nni.trace
 class _RegressionModule(_SupervisedLearningModule):
-    def __init__(self, criterion: nn.Module = nn.MSELoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  export_onnx: bool = True):
         super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,

@@ -328,10 +332,10 @@ class Regression(Lightning):
     `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
     """

-    def __init__(self, criterion: nn.Module = nn.MSELoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  train_dataloader: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
...
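The recurring change in this file is annotating `criterion` and `optimizer` as classes rather than instances: `Type[nn.Module]` and `Type[optim.Optimizer]`. The evaluator stores the class and instantiates it later (see `configure_optimizers`, which calls `self.optimizer(self.parameters(), ...)`). A minimal sketch of why `Type[...]` is the honest annotation for such "pass the class, not the object" parameters (a standalone toy function, not NNI code):

    from typing import Type

    import torch.nn as nn
    import torch.optim as optim

    def build(criterion_cls: Type[nn.Module] = nn.CrossEntropyLoss,
              optimizer_cls: Type[optim.Optimizer] = optim.Adam):
        # The arguments are classes; instantiation happens here. A bare
        # `nn.Module` annotation would wrongly claim an instance is expected.
        model = nn.Linear(4, 2)
        criterion = criterion_cls()
        optimizer = optimizer_cls(model.parameters(), lr=1e-3)  # type: ignore[call-arg]
        return model, criterion, optimizer

The `# type: ignore` mirrors the one added in the diff: the base `optim.Optimizer.__init__` does not declare `lr`, so calling through `Type[optim.Optimizer]` still needs a suppression even though every concrete optimizer accepts it.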
nni/retiarii/execution/__init__.py  (view file @ 18962129)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from .api import *
nni/retiarii/execution/base.py  (view file @ 18962129)

@@ -129,6 +129,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
     @classmethod
     def pack_model_data(cls, model: Model) -> Any:
         mutation_summary = get_mutation_summary(model)
+        assert model.evaluator is not None, 'Model evaluator can not be None'
         return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)

     @classmethod
...
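The single added line is a type-narrowing assert: `Model.evaluator` is declared `Optional`, and `BaseGraphData` wants a non-None evaluator, so `assert model.evaluator is not None` documents the invariant and narrows the static type on the success path. A self-contained illustration of the idiom (hypothetical `Job`/`pack` names, not NNI's classes):

    from typing import Optional

    class Job:
        def __init__(self, evaluator: Optional[str] = None) -> None:
            self.evaluator = evaluator

    def pack(job: Job) -> str:
        # Without the assert, a checker reports Optional[str] where str is
        # required; with it, the attribute is narrowed to str below.
        assert job.evaluator is not None, 'evaluator can not be None'
        return job.evaluator.upper()

Checkers narrow an attribute like `job.evaluator` only while nothing in between could reassign it, which is why the assert sits directly before the use.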
nni/retiarii/execution/benchmark.py  (view file @ 18962129)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import os
 import random
-from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable
+from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast

 from ..graph import Model
 from ..integration_api import receive_trial_parameters

@@ -39,6 +42,9 @@ class BenchmarkGraphData:
     def load(data) -> 'BenchmarkGraphData':
         return BenchmarkGraphData(data['mutation'], data['benchmark'], data['metric_name'], data['db_path'])

+    def __repr__(self) -> str:
+        return f"BenchmarkGraphData({self.mutation}, {self.benchmark}, {self.db_path})"
+

 class BenchmarkExecutionEngine(BaseExecutionEngine):
     """

@@ -67,6 +73,7 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
     @classmethod
     def trial_execute_graph(cls) -> None:
         graph_data = BenchmarkGraphData.load(receive_trial_parameters())
+        assert graph_data.db_path is not None, f'Invalid graph data because db_path is None: {graph_data}'
         os.environ['NASBENCHMARK_DIR'] = graph_data.db_path
         final, intermediates = cls.query_in_benchmark(graph_data)

@@ -89,7 +96,6 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
                 arch = t
         if arch is None:
             raise ValueError(f'Cannot identify architecture from mutation dict: {graph_data.mutation}')
-        print(arch)
         return _convert_to_final_and_intermediates(
             query_nb101_trial_stats(arch, 108, include_intermediates=True),
             'valid_acc'

@@ -146,4 +152,5 @@ def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_
         benchmark_result = random.choice(benchmark_result)
     else:
         benchmark_result = benchmark_result[0]
+    benchmark_result = cast(dict, benchmark_result)
     return benchmark_result[metric_name], [i[metric_name] for i in benchmark_result['intermediates'] if i[metric_name] is not None]
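The added `benchmark_result = cast(dict, benchmark_result)` is purely static: after `random.choice` or indexing into an `Iterable[Any]`, the checker only knows the value as `Any`, and `typing.cast` reasserts `dict` without performing any runtime conversion or check. A minimal sketch (toy data and a hypothetical `pick` helper):

    import random
    from typing import Any, List, cast

    def pick(results: List[Any]) -> float:
        chosen = random.choice(results)  # static type is Any here
        chosen = cast(dict, chosen)      # runtime no-op; informs the checker only
        return chosen['valid_acc']

    print(pick([{'valid_acc': 0.91}, {'valid_acc': 0.93}]))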
nni/retiarii/execution/cgo_engine.py  (view file @ 18962129)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

 import logging
 import os
 import random
...
nni/retiarii/execution/python.py  (view file @ 18962129)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from typing import Dict, Any, Type

 import torch.nn as nn

@@ -49,7 +52,8 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
     @classmethod
     def pack_model_data(cls, model: Model) -> Any:
         mutation = get_mutation_dict(model)
-        graph_data = PythonGraphData(model.python_class, model.python_init_params, mutation, model.evaluator)
+        assert model.evaluator is not None, 'Model evaluator is not available.'
+        graph_data = PythonGraphData(model.python_class, model.python_init_params or {}, mutation, model.evaluator)
         return graph_data

     @classmethod
...
nni/retiarii/execution/utils.py  (view file @ 18962129)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from typing import Any, List

 from ..graph import Model
...
nni/retiarii/experiment/pytorch.py  (view file @ 18962129)

@@ -11,7 +11,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from subprocess import Popen
 from threading import Thread
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Union, cast

 import colorama
 import psutil

@@ -23,6 +23,7 @@ from nni.experiment import Experiment, launcher, management, rest
 from nni.experiment.config import utils
 from nni.experiment.config.base import ConfigBase
 from nni.experiment.config.training_service import TrainingServiceConfig
+from nni.experiment.config.training_services import RemoteConfig
 from nni.experiment.pipe import Pipe
 from nni.tools.nnictl.command_utils import kill_command

@@ -222,6 +223,7 @@ class RetiariiExperiment(Experiment):
     Examples
     --------
     Multi-trial NAS:
+
     >>> base_model = Net()
     >>> search_strategy = strategy.Random()
     >>> model_evaluator = FunctionalEvaluator(evaluate_model)

@@ -233,6 +235,7 @@ class RetiariiExperiment(Experiment):
     >>> exp.run(exp_config, 8081)

     One-shot NAS:
+
     >>> base_model = Net()
     >>> search_strategy = strategy.DARTS()
     >>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)

@@ -242,15 +245,16 @@ class RetiariiExperiment(Experiment):
     >>> exp.run(exp_config)

     Export top models:
+
     >>> for model_dict in exp.export_top_models(formatter='dict'):
     ...     print(model_dict)
     >>> with nni.retarii.fixed_arch(model_dict):
     ...     final_model = Net()
     """

-    def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = None,
-                 applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None,
-                 trainer: BaseOneShotTrainer = None):
+    def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
+                 applied_mutators: List[Mutator] = cast(List[Mutator], None),
+                 strategy: BaseStrategy = cast(BaseStrategy, None),
+                 trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
         if trainer is not None:
             warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
                           'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)

@@ -260,21 +264,22 @@ class RetiariiExperiment(Experiment):
             raise ValueError('Evaluator should not be none.')

         # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
-        self.config: RetiariiExeConfig = None
+        self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
         self.port: Optional[int] = None

         self.base_model = base_model
-        self.evaluator: Evaluator = evaluator
+        self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
         self.applied_mutators = applied_mutators
         self.strategy = strategy

         # FIXME: this is only a workaround
         from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
         if not isinstance(strategy, OneShotStrategy):
             self._dispatcher = RetiariiAdvisor()
+        else:
+            self._dispatcher = cast(RetiariiAdvisor, None)
         self._dispatcher_thread: Optional[Thread] = None
         self._proc: Optional[Popen] = None
         self._pipe: Optional[Pipe] = None

         self.url_prefix = None

@@ -325,7 +330,7 @@ class RetiariiExperiment(Experiment):
             assert self.config.training_service.platform == 'remote', \
                 "CGO execution engine currently only supports remote training service"
-            assert self.config.batch_waiting_time is not None
+            assert self.config.batch_waiting_time is not None and self.config.max_concurrency_cgo is not None
             devices = self._construct_devices()
             engine = CGOExecutionEngine(devices,
                                         max_concurrency=self.config.max_concurrency_cgo,

@@ -335,7 +340,10 @@ class RetiariiExperiment(Experiment):
             engine = PurePythonExecutionEngine()
         elif self.config.execution_engine == 'benchmark':
             from ..execution.benchmark import BenchmarkExecutionEngine
+            assert self.config.benchmark is not None, '"benchmark" must be set when benchmark execution engine is used.'
             engine = BenchmarkExecutionEngine(self.config.benchmark)
+        else:
+            raise ValueError(f'Unsupported engine type: {self.config.execution_engine}')
         set_execution_engine(engine)

         self.id = management.generate_experiment_id()

@@ -377,9 +385,10 @@ class RetiariiExperiment(Experiment):
     def _construct_devices(self):
         devices = []
         if hasattr(self.config.training_service, 'machine_list'):
-            for machine in self.config.training_service.machine_list:
+            for machine in cast(RemoteConfig, self.config.training_service).machine_list:
                 assert machine.gpu_indices is not None, \
                     'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
+                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
                 for gpu_idx in machine.gpu_indices:
                     devices.append(GPUDevice(machine.host, gpu_idx))
         return devices

@@ -387,7 +396,7 @@ class RetiariiExperiment(Experiment):
     def _create_dispatcher(self):
         return self._dispatcher

-    def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
+    def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> None:
         """
         Run the experiment.
         This function will block until experiment finish or error.

@@ -420,6 +429,7 @@ class RetiariiExperiment(Experiment):
         This function will block until experiment finish or error.
         Return `True` when experiment done; or return `False` when experiment failed.
         """
+        assert self._proc is not None
         try:
             while True:
                 time.sleep(10)

@@ -437,6 +447,7 @@ class RetiariiExperiment(Experiment):
             _logger.warning('KeyboardInterrupt detected')
         finally:
             self.stop()
+        raise RuntimeError('Check experiment status failed.')

     def stop(self) -> None:
         """

@@ -466,11 +477,11 @@ class RetiariiExperiment(Experiment):
         if self._pipe is not None:
             self._pipe.close()

-        self.id = None
-        self.port = None
+        self.id = cast(str, None)
+        self.port = cast(int, None)
         self._proc = None
         self._pipe = None
-        self._dispatcher = None
+        self._dispatcher = cast(RetiariiAdvisor, None)
         self._dispatcher_thread = None
         _logger.info('Experiment stopped')
...
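A pattern repeated throughout this constructor is replacing `param: T = None` (invalid under strict typing unless `T` is `Optional`) with `param: T = cast(T, None)`. The annotation stays non-Optional, so the rest of the class needs no `None` handling to satisfy the checker, but the default genuinely is `None` at runtime and must still be guarded by logic such as the `ValueError` above. A sketch of the trade-off with a hypothetical `Strategy` class:

    from typing import cast

    class Strategy:
        name = 'random'

    def run(strategy: Strategy = cast(Strategy, None)) -> str:
        # The checker sees a non-Optional Strategy; at runtime the default
        # is still None, so an explicit guard remains necessary.
        if strategy is None:
            raise ValueError('strategy must not be None')
        return strategy.name

    print(run(Strategy()))  # 'random'

The alternative, `Optional[Strategy] = None`, is more honest but forces `None` checks at every use site; the diff opts for `cast` to keep caller-facing signatures unchanged.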
nni/retiarii/fixed.py  (view file @ 18962129)

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import json
 import logging
 from pathlib import Path
...
nni/retiarii/graph.py  (view file @ 18962129)

@@ -5,10 +5,16 @@
 Model representation.
 """

+from __future__ import annotations
+
 import abc
 import json
 from enum import Enum
-from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast, overload)
+
+if TYPE_CHECKING:
+    from .mutator import Mutator

 from .operation import Cell, Operation, _IOPseudoOperation
 from .utils import uid

@@ -63,7 +69,7 @@ class Evaluator(abc.ABC):
         pass

     @abc.abstractmethod
-    def _execute(self, model_cls: type) -> Any:
+    def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
         pass

     @abc.abstractmethod

@@ -203,7 +209,7 @@ class Model:
             matched_nodes.extend(nodes)
         return matched_nodes

-    def get_node_by_name(self, node_name: str) -> 'Node':
+    def get_node_by_name(self, node_name: str) -> 'Node' | None:
         """
         Traverse all the nodes to find the matched node with the given name.
         """

@@ -217,7 +223,7 @@ class Model:
         else:
             return None

-    def get_node_by_python_name(self, python_name: str) -> 'Node':
+    def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
         """
         Traverse all the nodes to find the matched node with the given python_name.
         """

@@ -297,7 +303,7 @@ class Graph:
         The name of torch.nn.Module, should have one-to-one mapping with items in python model.
     """

-    def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
+    def __init__(self, model: Model, graph_id: int, name: str = cast(str, None), _internal: bool = False):
         assert _internal, '`Graph()` is private'

         self.model: Model = model

@@ -338,9 +344,9 @@ class Graph:
     @overload
     def add_node(self, name: str, operation: Operation) -> 'Node': ...
     @overload
-    def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
+    def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...

-    def add_node(self, name, operation_or_type, parameters=None):
+    def add_node(self, name, operation_or_type, parameters=None):  # type: ignore
         if isinstance(operation_or_type, Operation):
             op = operation_or_type
         else:

@@ -350,9 +356,10 @@ class Graph:
     @overload
     def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
     @overload
-    def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
+    def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str,
+                            parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...

-    def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':
+    def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':  # type: ignore
         if isinstance(operation_or_type, Operation):
             op = operation_or_type
         else:

@@ -405,7 +412,7 @@ class Graph:
     def get_nodes_by_name(self, name: str) -> List['Node']:
         return [node for node in self.hidden_nodes if node.name == name]

-    def get_nodes_by_python_name(self, python_name: str) -> Optional['Node']:
+    def get_nodes_by_python_name(self, python_name: str) -> List['Node']:
         return [node for node in self.nodes if node.python_name == python_name]

     def topo_sort(self) -> List['Node']:

@@ -594,7 +601,7 @@ class Node:
         return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))

     @property
-    def successor_slots(self) -> List[Tuple['Node', Union[int, None]]]:
+    def successor_slots(self) -> Set[Tuple['Node', Union[int, None]]]:
         return set((edge.tail, edge.tail_slot) for edge in self.outgoing_edges)

     @property

@@ -610,19 +617,19 @@ class Node:
         assert isinstance(self.operation, Cell)
         return self.graph.model.graphs[self.operation.parameters['cell']]

-    def update_label(self, label: str) -> None:
+    def update_label(self, label: Optional[str]) -> None:
         self.label = label

     @overload
     def update_operation(self, operation: Operation) -> None: ...
     @overload
-    def update_operation(self, type_name: str, parameters: Dict[str, Any] = None) -> None: ...
+    def update_operation(self, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> None: ...

-    def update_operation(self, operation_or_type, parameters=None):
+    def update_operation(self, operation_or_type, parameters=None):  # type: ignore
         if isinstance(operation_or_type, Operation):
             self.operation = operation_or_type
         else:
-            self.operation = Operation.new(operation_or_type, parameters)
+            self.operation = Operation.new(operation_or_type, cast(dict, parameters))

     # mutation
     def remove(self) -> None:

@@ -663,7 +670,13 @@ class Node:
         return node

     def _dump(self) -> Any:
-        ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters, 'attributes': self.operation.attributes}}
+        ret: Dict[str, Any] = {
+            'operation': {
+                'type': self.operation.type,
+                'parameters': self.operation.parameters,
+                'attributes': self.operation.attributes
+            }
+        }
         if isinstance(self.operation, Cell):
             ret['operation']['cell_name'] = self.operation.cell_name
         if self.label is not None:
...
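Two additions here work together: `from __future__ import annotations` makes every annotation a string that is never evaluated at runtime, and the `if TYPE_CHECKING:` block imports `Mutator` only for the type checker. That is what makes a return type like `'Node' | None` legal on Python versions before 3.10 and avoids a runtime import cycle between `graph` and `mutator`. A condensed sketch of the pattern (hypothetical flat module name `mutator`):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only the type checker sees this import, so a graph <-> mutator
        # circular dependency never materializes at runtime.
        from mutator import Mutator

    class Node:
        def find(self, name: str) -> Node | None:
            # Legal even below Python 3.10: the future import keeps this
            # annotation as a string instead of evaluating `Node | None`.
            return self if name == 'root' else None

        def apply(self, mutator: Mutator) -> None:
            # Referencing the TYPE_CHECKING-only name in an annotation is
            # safe for the same reason.
            pass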
nni/retiarii/hub/pytorch/mobilenetv3.py  (view file @ 18962129)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, cast

 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import model_wrapper

@@ -75,10 +75,10 @@ class MobileNetV3Space(nn.Module):
                  bn_momentum: float = 0.1):
         super().__init__()

-        self.widths = [
+        self.widths = cast(nn.ChoiceOf[int], [
             nn.ValueChoice([make_divisible(base_width * mult, 8) for mult in width_multipliers], label=f'width_{i}')
             for i, base_width in enumerate(base_widths)
-        ]
+        ])

         self.expand_ratios = expand_ratios

         blocks = [

@@ -115,7 +115,7 @@ class MobileNetV3Space(nn.Module):
         self.classifier = nn.Sequential(
             nn.Dropout(dropout_rate),
-            nn.Linear(self.widths[7], num_labels),
+            nn.Linear(cast(int, self.widths[7]), num_labels),
         )

         reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)
...
nni/retiarii/hub/pytorch/nasbench101.py  (view file @ 18962129)

@@ -3,6 +3,7 @@
 import math

+import torch
 import torch.nn as nn
 from nni.retiarii import model_wrapper
 from nni.retiarii.nn.pytorch import NasBench101Cell

@@ -11,7 +12,7 @@ from nni.retiarii.nn.pytorch import NasBench101Cell
 __all__ = ['NasBench101']

-def truncated_normal_(tensor, mean=0, std=1):
+def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1):
     # https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
     size = tensor.shape
     tmp = tensor.new_empty(size + (4,)).normal_()

@@ -117,9 +118,3 @@ class NasBench101(nn.Module):
         out = self.gap(out).view(bs, -1)
         out = self.classifier(out)
         return out
-
-    def reset_parameters(self):
-        for module in self.modules():
-            if isinstance(module, nn.BatchNorm2d):
-                module.eps = self.config.bn_eps
-                module.momentum = self.config.bn_momentum
nni/retiarii/hub/pytorch/nasbench201.py  (view file @ 18962129)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from typing import Callable, Dict
+
 import torch
 import torch.nn as nn

@@ -176,8 +178,10 @@ class NasBench201(nn.Module):
             if reduction:
                 cell = ResNetBasicblock(C_prev, C_curr, 2)
             else:
-                cell = NasBench201Cell({prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES},
-                                       C_prev, C_curr, label='cell')
+                ops: Dict[str, Callable[[int, int], nn.Module]] = {
+                    prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
+                }
+                cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
             self.cells.append(cell)
             C_prev = C_curr
...
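Besides the header, the main change hoists the inline dict argument into a named, annotated local: `ops: Dict[str, Callable[[int, int], nn.Module]]` records that every value is a factory mapping `(C_in, C_out)` to a module, which was invisible in the one-liner. A minimal sketch of that shape with a toy op table (illustrative op names, not the actual NAS-Bench-201 primitives):

    from typing import Callable, Dict

    import torch.nn as nn

    # Each value is a factory taking (C_in, C_out) and returning a module.
    ops: Dict[str, Callable[[int, int], nn.Module]] = {
        'conv_1x1': lambda c_in, c_out: nn.Conv2d(c_in, c_out, 1),
        'conv_3x3': lambda c_in, c_out: nn.Conv2d(c_in, c_out, 3, padding=1),
        'skip_connect': lambda c_in, c_out: nn.Identity(),
    }

    layer = ops['conv_3x3'](16, 32)  # builds a 16 -> 32 channel conv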
nni/retiarii/hub/pytorch/nasnet.py  (view file @ 18962129)

@@ -8,7 +8,7 @@ It's called ``nasnet.py`` simply because NASNet is the first to propose such str
 """

 from collections import OrderedDict
-from typing import Tuple, List, Union, Iterable, Dict, Callable
+from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast

 try:
     from typing import Literal

@@ -250,14 +250,14 @@ class CellPreprocessor(nn.Module):
     See :class:`CellBuilder` on how to calculate those channel numbers.
     """

-    def __init__(self, C_pprev: int, C_prev: int, C: int, last_cell_reduce: bool) -> None:
+    def __init__(self, C_pprev: nn.MaybeChoice[int], C_prev: nn.MaybeChoice[int], C: nn.MaybeChoice[int], last_cell_reduce: bool) -> None:
         super().__init__()

         if last_cell_reduce:
-            self.pre0 = FactorizedReduce(C_pprev, C)
+            self.pre0 = FactorizedReduce(cast(int, C_pprev), cast(int, C))
         else:
-            self.pre0 = ReLUConvBN(C_pprev, C, 1, 1, 0)
-        self.pre1 = ReLUConvBN(C_prev, C, 1, 1, 0)
+            self.pre0 = ReLUConvBN(cast(int, C_pprev), cast(int, C), 1, 1, 0)
+        self.pre1 = ReLUConvBN(cast(int, C_prev), cast(int, C), 1, 1, 0)

     def forward(self, cells):
         assert len(cells) == 2

@@ -283,15 +283,19 @@ class CellBuilder:
     Note that the builder is ephemeral, it can only be called once for every index.
     """

-    def __init__(self, op_candidates: List[str], C_prev_in: int, C_in: int, C: int,
-                 num_nodes: int, merge_op: Literal['all', 'loose_end'],
+    def __init__(self, op_candidates: List[str],
+                 C_prev_in: nn.MaybeChoice[int],
+                 C_in: nn.MaybeChoice[int],
+                 C: nn.MaybeChoice[int],
+                 num_nodes: int,
+                 merge_op: Literal['all', 'loose_end'],
                  first_cell_reduce: bool, last_cell_reduce: bool):
         self.C_prev_in = C_prev_in      # This is the out channels of the cell before last cell.
         self.C_in = C_in      # This is the out channesl of last cell.
         self.C = C      # This is NOT C_out of this stage, instead, C_out = C * len(cell.output_node_indices)
         self.op_candidates = op_candidates
         self.num_nodes = num_nodes
-        self.merge_op = merge_op
+        self.merge_op: Literal['all', 'loose_end'] = merge_op
         self.first_cell_reduce = first_cell_reduce
         self.last_cell_reduce = last_cell_reduce
         self._expect_idx = 0

@@ -312,7 +316,7 @@ class CellBuilder:
         # self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
         preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C, self.last_cell_reduce)

-        ops_factory: Dict[str, Callable[[int, int, int], nn.Module]] = {
+        ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {
             op:  # make final chosen ops named with their aliases
             lambda node_index, op_index, input_index:
                 OPS[op](self.C, 2 if is_reduction_cell and (

@@ -353,7 +357,7 @@ _INIT_PARAMETER_DOCS = """

 class NDS(nn.Module):
-    """
+    __doc__ = """
     The unified version of NASNet search space.

     We follow the implementation in

@@ -378,8 +382,8 @@ class NDS(nn.Module):
                  op_candidates: List[str],
                  merge_op: Literal['all', 'loose_end'] = 'all',
                  num_nodes_per_cell: int = 4,
-                 width: Union[Tuple[int], int] = 16,
-                 num_cells: Union[Tuple[int], int] = 20,
+                 width: Union[Tuple[int, ...], int] = 16,
+                 num_cells: Union[Tuple[int, ...], int] = 20,
                  dataset: Literal['cifar', 'imagenet'] = 'imagenet',
                  auxiliary_loss: bool = False):
         super().__init__()

@@ -394,30 +398,31 @@ class NDS(nn.Module):
         else:
             C = width

+        self.num_cells: nn.MaybeChoice[int] = cast(int, num_cells)
         if isinstance(num_cells, Iterable):
-            num_cells = nn.ValueChoice(list(num_cells), label='depth')
-        num_cells_per_stage = [i * num_cells // 3 - (i - 1) * num_cells // 3 for i in range(3)]
+            self.num_cells = nn.ValueChoice(list(num_cells), label='depth')
+        num_cells_per_stage = [i * self.num_cells // 3 - (i - 1) * self.num_cells // 3 for i in range(3)]

         # auxiliary head is different for network targetted at different datasets
         if dataset == 'imagenet':
             self.stem0 = nn.Sequential(
-                nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
-                nn.BatchNorm2d(C // 2),
+                nn.Conv2d(3, cast(int, C // 2), kernel_size=3, stride=2, padding=1, bias=False),
+                nn.BatchNorm2d(cast(int, C // 2)),
                 nn.ReLU(inplace=True),
-                nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
+                nn.Conv2d(cast(int, C // 2), cast(int, C), 3, stride=2, padding=1, bias=False),
                 nn.BatchNorm2d(C),
             )
             self.stem1 = nn.Sequential(
                 nn.ReLU(inplace=True),
-                nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
+                nn.Conv2d(cast(int, C), cast(int, C), 3, stride=2, padding=1, bias=False),
                 nn.BatchNorm2d(C),
             )
             C_pprev = C_prev = C_curr = C
             last_cell_reduce = True
         elif dataset == 'cifar':
             self.stem = nn.Sequential(
-                nn.Conv2d(3, 3 * C, 3, padding=1, bias=False),
-                nn.BatchNorm2d(3 * C)
+                nn.Conv2d(3, cast(int, 3 * C), 3, padding=1, bias=False),
+                nn.BatchNorm2d(cast(int, 3 * C))
             )
             C_pprev = C_prev = 3 * C
             C_curr = C

@@ -439,7 +444,7 @@ class NDS(nn.Module):
                 # C_pprev is output channel number of last second cell among all the cells already built.
                 if len(stage) > 1:
                     # Contains more than one cell
-                    C_pprev = len(stage[-2].output_node_indices) * C_curr
+                    C_pprev = len(cast(nn.Cell, stage[-2]).output_node_indices) * C_curr
                 else:
                     # Look up in the out channels of last stage.
                     C_pprev = C_prev

@@ -447,7 +452,7 @@ class NDS(nn.Module):
                 # This was originally,
                 # C_prev = num_nodes_per_cell * C_curr.
                 # but due to loose end, it becomes,
-                C_prev = len(stage[-1].output_node_indices) * C_curr
+                C_prev = len(cast(nn.Cell, stage[-1]).output_node_indices) * C_curr

                 # Useful in aligning the pprev and prev cell.
                 last_cell_reduce = cell_builder.last_cell_reduce

@@ -457,11 +462,11 @@ class NDS(nn.Module):
         if auxiliary_loss:
             assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
-            self.stages[2] = SequentialBreakdown(self.stages[2])
-            self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)
+            self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
+            self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)  # type: ignore

         self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
-        self.classifier = nn.Linear(C_prev, self.num_labels)
+        self.classifier = nn.Linear(cast(int, C_prev), self.num_labels)

     def forward(self, inputs):
         if self.dataset == 'imagenet':

@@ -483,7 +488,7 @@ class NDS(nn.Module):
         out = self.global_pooling(s1)
         logits = self.classifier(out.view(out.size(0), -1))
         if self.training and self.auxiliary_loss:
-            return logits, logits_aux
+            return logits, logits_aux  # type: ignore
         else:
             return logits

@@ -524,8 +529,8 @@ class NASNet(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.NASNET_OPS,

@@ -555,8 +560,8 @@ class ENAS(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.ENAS_OPS,

@@ -590,8 +595,8 @@ class AmoebaNet(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):

@@ -626,8 +631,8 @@ class PNAS(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.PNAS_OPS,

@@ -660,8 +665,8 @@ class DARTS(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.DARTS_OPS,
...
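Most of this file's churn wraps channel arithmetic in `cast(int, ...)`. The width parameters are now `nn.MaybeChoice[int]`, i.e. either a plain int or a symbolic `ValueChoice`; an expression like `C // 2` stays symbolic, while APIs such as `nn.Conv2d` are annotated to take `int`, so `cast` bridges the two for the checker and the symbolic object simply flows through at runtime. A simplified illustration with a toy stand-in for the symbolic type (the `Sym` class below is hypothetical, not NNI's `ValueChoice`):

    from typing import Union, cast

    class Sym:
        # Toy symbolic value: arithmetic returns another symbol.
        def __init__(self, expr: str) -> None:
            self.expr = expr
        def __floordiv__(self, k: int) -> 'Sym':
            return Sym(f'({self.expr} // {k})')

    MaybeSym = Union[Sym, int]

    def half_width(C: MaybeSym) -> int:
        half = C // 2              # still a Sym when C is a Sym
        # Runtime no-op: the symbol passes straight through, exactly like
        # cast(int, C // 2) around nn.Conv2d in the diff above.
        return cast(int, half)

    print(half_width(32))  # 16; with Sym('C') the result is Sym('(C // 2)') at runtime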
nni/retiarii/hub/pytorch/proxylessnas.py
View file @
18962129
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
math
import
math
from
typing
import
Optional
,
Callable
,
List
,
Tuple
from
typing
import
Optional
,
Callable
,
List
,
Tuple
,
cast
import
torch
import
torch
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
...
@@ -31,12 +31,12 @@ class ConvBNReLU(nn.Sequential):
...
@@ -31,12 +31,12 @@ class ConvBNReLU(nn.Sequential):
def
__init__
(
     def __init__(
         self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int = 3,
+        in_channels: nn.MaybeChoice[int],
+        out_channels: nn.MaybeChoice[int],
+        kernel_size: nn.MaybeChoice[int] = 3,
         stride: int = 1,
-        groups: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        groups: nn.MaybeChoice[int] = 1,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
         dilation: int = 1,
     ) -> None:
...
@@ -46,9 +46,17 @@ class ConvBNReLU(nn.Sequential):
         if activation_layer is None:
             activation_layer = nn.ReLU6
         super().__init__(
-            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups,
-                      bias=False),
-            norm_layer(out_channels),
+            nn.Conv2d(
+                cast(int, in_channels),
+                cast(int, out_channels),
+                cast(int, kernel_size),
+                stride,
+                cast(int, padding),
+                dilation=dilation,
+                groups=cast(int, groups),
+                bias=False
+            ),
+            norm_layer(cast(int, out_channels)),
             activation_layer(inplace=True)
         )
         self.out_channels = out_channels
...
@@ -62,11 +70,11 @@ class SeparableConv(nn.Sequential):
     def __init__(
         self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int = 3,
+        in_channels: nn.MaybeChoice[int],
+        out_channels: nn.MaybeChoice[int],
+        kernel_size: nn.MaybeChoice[int] = 3,
         stride: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         super().__init__(
...
@@ -101,13 +109,13 @@ class InvertedResidual(nn.Sequential):
     def __init__(
         self,
-        in_channels: int,
-        out_channels: int,
-        expand_ratio: int,
-        kernel_size: int = 3,
+        in_channels: nn.MaybeChoice[int],
+        out_channels: nn.MaybeChoice[int],
+        expand_ratio: nn.MaybeChoice[float],
+        kernel_size: nn.MaybeChoice[int] = 3,
         stride: int = 1,
-        squeeze_and_excite: Optional[Callable[[int], nn.Module]] = None,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        squeeze_and_excite: Optional[Callable[[nn.MaybeChoice[int]], nn.Module]] = None,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         super().__init__()
...
@@ -115,7 +123,7 @@ class InvertedResidual(nn.Sequential):
         self.out_channels = out_channels
         assert stride in [1, 2]

-        hidden_ch = nn.ValueChoice.to_int(round(in_channels * expand_ratio))
+        hidden_ch = nn.ValueChoice.to_int(round(cast(int, in_channels * expand_ratio)))

         # FIXME: check whether this equal works
         # Residual connection is added here stride = 1 and input channels and output channels are the same.
...
@@ -215,7 +223,7 @@ class ProxylessNAS(nn.Module):
         self.first_conv = ConvBNReLU(3, widths[0], stride=2, norm_layer=nn.BatchNorm2d)

-        blocks = [
+        blocks: List[nn.Module] = [
            # first stage is fixed
            SeparableConv(widths[0], widths[1], kernel_size=3, stride=1)
        ]
...
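The recurring pattern in this file is worth calling out: module signatures widen from `int` to `nn.MaybeChoice[int]`, and each value is passed through `cast(int, ...)` exactly where a concrete layer is constructed. The `cast` is a no-op at runtime; it only reconciles the static type checker with the fact that retiarii's `nn.Conv2d` can consume `ValueChoice` expressions. A minimal sketch of the idea (the `SimpleConvBlock` class and its label are illustrative, not part of this commit):

    from typing import cast

    import nni.retiarii.nn.pytorch as nn

    class SimpleConvBlock(nn.Sequential):
        """Accepts plain ints or ValueChoice expressions for its channel counts."""

        def __init__(self, in_channels: nn.MaybeChoice[int], out_channels: nn.MaybeChoice[int]):
            super().__init__(
                # cast(int, ...) does nothing at runtime; it silences static checkers
                # while retiarii's Conv2d handles the ValueChoice at graph level.
                nn.Conv2d(cast(int, in_channels), cast(int, out_channels), 3, padding=1),
                nn.BatchNorm2d(cast(int, out_channels)),
                nn.ReLU(inplace=True),
            )

    width = nn.ValueChoice([16, 32, 64], label='width')
    block = SimpleConvBlock(3, width)   # type-checks as int, searched at runtime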
nni/retiarii/hub/pytorch/shufflenet.py
View file @ 18962129

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from typing import cast
+
 import torch
 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import model_wrapper
...
@@ -14,7 +16,7 @@ class ShuffleNetBlock(nn.Module):
     When stride = 1, the block expects an input with ``2 * input channels``. Otherwise input channels.
     """

-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *,
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *,
                  kernel_size: int, stride: int, sequence: str = "pdp", affine: bool = True):
         super().__init__()
         assert stride in [1, 2]
...
@@ -57,14 +59,15 @@ class ShuffleNetBlock(nn.Module):
     def _decode_point_depth_conv(self, sequence):
         result = []
         first_depth = first_point = True
-        pc = c = self.channels
+        pc: int = self.channels
+        c: int = self.channels
         for i, token in enumerate(sequence):
             # compute output channels of this conv
             if i + 1 == len(sequence):
                 assert token == "p", "Last conv must be point-wise conv."
                 c = self.oup_main
             elif token == "p" and first_point:
-                c = self.mid_channels
+                c = cast(int, self.mid_channels)
             if token == "d":
                 # depth-wise conv
                 if isinstance(pc, int) and isinstance(c, int):
...
@@ -101,7 +104,7 @@ class ShuffleXceptionBlock(ShuffleNetBlock):
     `Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
     """

-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *, stride: int, affine: bool = True):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *, stride: int, affine: bool = True):
         super().__init__(in_channels, out_channels, mid_channels,
                          kernel_size=3, stride=stride, sequence="dpdpdp", affine=affine)
...
@@ -154,7 +157,7 @@ class ShuffleNetSpace(nn.Module):
             nn.ReLU(inplace=True),
         )

-        self.features = []
+        feature_blocks = []

         global_block_idx = 0
         for stage_idx, num_repeat in enumerate(self.stage_repeats):
...
@@ -175,15 +178,17 @@ class ShuffleNetSpace(nn.Module):
                 else:
                     mid_channels = int(base_mid_channels)

+                mid_channels = cast(nn.MaybeChoice[int], mid_channels)
+
                 choice_block = nn.LayerChoice([
                     ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=3, stride=stride, affine=affine),
                     ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=5, stride=stride, affine=affine),
                     ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=7, stride=stride, affine=affine),
                     ShuffleXceptionBlock(in_channels, out_channels, mid_channels=mid_channels, stride=stride, affine=affine)
                 ], label=f'layer_{global_block_idx}')
-                self.features.append(choice_block)
+                feature_blocks.append(choice_block)

-        self.features = nn.Sequential(*self.features)
+        self.features = nn.Sequential(*feature_blocks)

         # final layers
         last_conv_channels = self.stage_out_channels[-1]
...
@@ -226,13 +231,15 @@ class ShuffleNetSpace(nn.Module):
                     torch.nn.init.constant_(m.weight, 1)
                 if m.bias is not None:
                     torch.nn.init.constant_(m.bias, 0.0001)
-                torch.nn.init.constant_(m.running_mean, 0)
+                if m.running_mean is not None:
+                    torch.nn.init.constant_(m.running_mean, 0)
             elif isinstance(m, nn.BatchNorm1d):
                 if m.weight is not None:
                     torch.nn.init.constant_(m.weight, 1)
                 if m.bias is not None:
                     torch.nn.init.constant_(m.bias, 0.0001)
-                torch.nn.init.constant_(m.running_mean, 0)
+                if m.running_mean is not None:
+                    torch.nn.init.constant_(m.running_mean, 0)
             elif isinstance(m, nn.Linear):
                 torch.nn.init.normal_(m.weight, 0, 0.01)
                 if m.bias is not None:
...
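Two small patterns above deserve a note. First, `mid_channels` is upcast to `nn.MaybeChoice[int]` once, so every block constructor below it type-checks against a single signature. Second, blocks are collected in a local `feature_blocks` list instead of in `self.features`, because reassigning the same attribute from `list` to `nn.Sequential` would give it two conflicting types. A condensed sketch with made-up channel numbers and labels:

    from typing import List

    import nni.retiarii.nn.pytorch as nn

    blocks: List[nn.Module] = []          # local list: its type never changes
    for idx in range(3):
        blocks.append(nn.LayerChoice([
            nn.Conv2d(16, 16, 3, padding=1),
            nn.Conv2d(16, 16, 5, padding=2),
        ], label=f'layer_{idx}'))
    features = nn.Sequential(*blocks)     # the attribute would be assigned exactly once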
nni/retiarii/hub/pytorch/utils.py
0 → 100644
View file @ 18962129

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+# Useful type hints
nni/retiarii/integration.py
View file @ 18962129

@@ -3,7 +3,7 @@
 import logging
 import os
-from typing import Any, Callable
+from typing import Any, Callable, Optional

 import nni
 from nni.common.serializer import PayloadTooLarge
...
@@ -53,11 +53,11 @@ class RetiariiAdvisor(MsgDispatcherBase):
         register_advisor(self)  # register the current advisor as the "global only" advisor

         self.search_space = None
-        self.send_trial_callback: Callable[[dict], None] = None
-        self.request_trial_jobs_callback: Callable[[int], None] = None
-        self.trial_end_callback: Callable[[int, bool], None] = None
-        self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
-        self.final_metric_callback: Callable[[int, MetricData], None] = None
+        self.send_trial_callback: Optional[Callable[[dict], None]] = None
+        self.request_trial_jobs_callback: Optional[Callable[[int], None]] = None
+        self.trial_end_callback: Optional[Callable[[int, bool], None]] = None
+        self.intermediate_metric_callback: Optional[Callable[[int, MetricData], None]] = None
+        self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None

         self.parameters_count = 0
...
@@ -158,19 +158,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
     def handle_trial_end(self, data):
         _logger.debug('Trial end: %s', data)
-        self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
-                                data['event'] == 'SUCCEEDED')
+        if self.trial_end_callback is not None:
+            self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
+                                    data['event'] == 'SUCCEEDED')

     def handle_report_metric_data(self, data):
         _logger.debug('Metric reported: %s', data)
         if data['type'] == MetricType.REQUEST_PARAMETER:
             raise ValueError('Request parameter not supported')
         elif data['type'] == MetricType.PERIODICAL:
-            self.intermediate_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
-                                              self._process_value(data['value']))
+            if self.intermediate_metric_callback is not None:
+                self.intermediate_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
+                                                  self._process_value(data['value']))
         elif data['type'] == MetricType.FINAL:
-            self.final_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
-                                       self._process_value(data['value']))
+            if self.final_metric_callback is not None:
+                self.final_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
+                                           self._process_value(data['value']))

     @staticmethod
     def _process_value(value) -> Any:  # hopefully a float
...
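The callback attributes are now declared `Optional[...]` and every call site gains a `None`-guard; previously the annotations promised a `Callable` while `None` was assigned, which is unsound under a strict type checker. The shape of the fix, distilled (class and attribute names here are illustrative, not NNI's API):

    from typing import Callable, Optional

    class Dispatcher:
        def __init__(self) -> None:
            # Optional[...] makes the initial None assignment type-correct.
            self.trial_end_callback: Optional[Callable[[int, bool], None]] = None

        def handle_trial_end(self, parameter_id: int, succeeded: bool) -> None:
            # The guard replaces the old `# pylint: disable=not-callable` suppression
            # and avoids a TypeError when no callback was ever registered.
            if self.trial_end_callback is not None:
                self.trial_end_callback(parameter_id, succeeded)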
nni/retiarii/mutator.py
View file @ 18962129

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import (Any, Iterable, List, Optional, Tuple)
+import warnings
+from typing import (Any, Iterable, List, Optional, Tuple, cast)

 from .graph import Model, Mutation, ModelStatus
...
@@ -44,9 +45,11 @@ class Mutator:
     If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
     """

-    def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None):
+    def __init__(self, sampler: Optional[Sampler] = None, label: str = cast(str, None)):
         self.sampler: Optional[Sampler] = sampler
-        self.label: Optional[str] = label
+        if label is None:
+            warnings.warn('Each mutator should have an explicit label. Mutator without label is deprecated.', DeprecationWarning)
+        self.label: str = label
         self._cur_model: Optional[Model] = None
         self._cur_choice_idx: Optional[int] = None
...
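`label: str = cast(str, None)` is a deliberate trick: the public annotation becomes `str` (labels are now effectively required), yet legacy callers that omit the label still run and merely receive a `DeprecationWarning`. A standalone sketch of the same move (the function is illustrative, not from this commit):

    import warnings
    from typing import cast

    def resolve_label(label: str = cast(str, None)) -> str:
        # cast(str, None) smuggles a None default past the `str` annotation,
        # keeping old call sites alive during the deprecation window.
        if label is None:
            warnings.warn('An explicit label is expected; omitting it is deprecated.',
                          DeprecationWarning)
        return label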
nni/retiarii/nn/pytorch/api.py
View file @ 18962129

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-import math
 import itertools
+import math
 import operator
 import warnings
-from typing import Any, List, Union, Dict, Optional, Callable, Iterable, NoReturn, TypeVar, Sequence
+from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List, NoReturn,
+                    Optional, Sequence, SupportsRound, TypeVar, Union, cast)

 import torch
 import torch.nn as nn

 from nni.common.hpo_utils import ParameterSpec
 from nni.common.serializer import Translatable
 from nni.retiarii.serializer import basic_unit
-from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace, NoContextError
+from nni.retiarii.utils import (STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace, NoContextError)

-__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'ModelParameterChoice', 'Placeholder', 'ChosenInputs']
+__all__ = [
+    # APIs
+    'LayerChoice', 'InputChoice', 'ValueChoice', 'ModelParameterChoice', 'Placeholder',
+    # Fixed module
+    'ChosenInputs',
+    # Type utils
+    'ReductionType', 'MaybeChoice', 'ChoiceOf',
+]
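Exporting `MaybeChoice` and `ChoiceOf` (defined near the bottom of this file) lets downstream search spaces annotate hyperparameters without spelling out the union themselves; the hub models above already rely on this via `nn.MaybeChoice[int]`. For instance (a hypothetical helper, not part of the commit):

    from typing import cast

    import nni.retiarii.nn.pytorch as nn

    def make_head(hidden_dim: nn.MaybeChoice[int]) -> nn.Module:
        # Works with hidden_dim = 128 as well as hidden_dim = nn.ValueChoice([...]).
        return nn.Linear(cast(int, hidden_dim), 10)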
 class LayerChoice(Mutable):
...
@@ -130,26 +147,16 @@ class LayerChoice(Mutable):
             self.names.append(str(i))
         else:
             raise TypeError("Unsupported candidates type: {}".format(type(candidates)))
-        self._first_module = self._modules[self.names[0]]  # to make the dummy forward meaningful
-
-    @property
-    def key(self):
-        return self._key()
-
-    @torch.jit.ignore
-    def _key(self):
-        warnings.warn('Using key to access the identifier of LayerChoice is deprecated. Please use label instead.',
-                      category=DeprecationWarning)
-        return self._label
+        self._first_module = cast(nn.Module, self._modules[self.names[0]])  # to make the dummy forward meaningful

     @property
     def label(self):
         return self._label

-    def __getitem__(self, idx):
+    def __getitem__(self, idx: Union[int, str]) -> nn.Module:
         if isinstance(idx, str):
-            return self._modules[idx]
-        return list(self)[idx]
+            return cast(nn.Module, self._modules[idx])
+        return cast(nn.Module, list(self)[idx])

     def __setitem__(self, idx, module):
         key = idx if isinstance(idx, str) else self.names[idx]
...
@@ -173,15 +180,6 @@ class LayerChoice(Mutable):
     def __iter__(self):
         return map(lambda name: self._modules[name], self.names)

-    @property
-    def choices(self):
-        return self._choices()
-
-    @torch.jit.ignore
-    def _choices(self):
-        warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.",
-                      category=DeprecationWarning)
-        return list(self)
-
     def forward(self, x):
         """
         The forward of layer choice is simply running the first candidate module.
...
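With the annotated `__getitem__`, indexing a `LayerChoice` now advertises `nn.Module` to the type checker for both integer and string keys (list-constructed candidates are registered under the names `'0'`, `'1'`, ...). A quick illustration based on the signatures in this hunk:

    import torch.nn as torch_nn

    import nni.retiarii.nn.pytorch as nn

    layer = nn.LayerChoice([torch_nn.ReLU(), torch_nn.Sigmoid()], label='act')
    first = layer[0]     # typed as nn.Module now, no cast needed at call sites
    same = layer['0']    # string keys hit the self._modules lookup branch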
@@ -266,16 +264,6 @@ class InputChoice(Mutable):
         assert self.reduction in ['mean', 'concat', 'sum', 'none']
         self._label = generate_new_label(label)

-    @property
-    def key(self):
-        return self._key()
-
-    @torch.jit.ignore
-    def _key(self):
-        warnings.warn('Using key to access the identifier of InputChoice is deprecated. Please use label instead.',
-                      category=DeprecationWarning)
-        return self._label
-
     @property
     def label(self):
         return self._label
...
@@ -350,7 +338,7 @@ def _valuechoice_codegen(*, _internal: bool = False):
         'truediv': '//', 'floordiv': '/', 'mod': '%',
         'lshift': '<<', 'rshift': '>>',
         'and': '&', 'xor': '^', 'or': '|',
-        # no reflection
+        # no reverse
         'lt': '<', 'le': '<=', 'eq': '==',
         'ne': '!=', 'ge': '>=', 'gt': '>',
         # NOTE
...
@@ -358,14 +346,14 @@ def _valuechoice_codegen(*, _internal: bool = False):
         # Might support them in future when we actually need them.
     }

-    binary_template = """    def __{op}__(self, other: Any) -> 'ValueChoiceX':
+    binary_template = """    def __{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [self, other])"""

-    binary_r_template = """    def __r{op}__(self, other: Any) -> 'ValueChoiceX':
+    binary_r_template = """    def __r{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [other, self])"""

-    unary_template = """    def __{op}__(self) -> 'ValueChoiceX':
-        return ValueChoiceX(operator.{op}, '{sym}{{}}', [self])"""
+    unary_template = """    def __{op}__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
+        return cast(ChoiceOf[_value], ValueChoiceX(operator.{op}, '{sym}{{}}', [self]))"""

     for op, sym in MAPPING.items():
         if op in ['neg', 'pos', 'invert']:
...
@@ -377,8 +365,14 @@ def _valuechoice_codegen(*, _internal: bool = False):
         print(binary_r_template.format(op=op, opt=opt, sym=sym) + '\n')

-def _valuechoice_staticmethod_helper(orig_func):
-    orig_func.__doc__ += """
+_func = TypeVar('_func')
+_cand = TypeVar('_cand')
+_value = TypeVar('_value')
+
+
+def _valuechoice_staticmethod_helper(orig_func: _func) -> _func:
+    if orig_func.__doc__ is not None:
+        orig_func.__doc__ += """
     Notes
     -----
     This function performs lazy evaluation.
...
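For reference, rendering `binary_template` with `op='add'`, `opt='add'`, `sym='+'` produces exactly the `__add__` that appears in the generated region further down this file:

    # What binary_template expands to for op='add', opt='add', sym='+'
    # (identical to the generated __add__ in the region below):
    def __add__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.add, '{} + {}', [self, other])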
@@ -388,7 +382,7 @@ def _valuechoice_staticmethod_helper(orig_func):
     return orig_func

-class ValueChoiceX(Translatable, nn.Module):
+class ValueChoiceX(Generic[_cand], Translatable, nn.Module):
     """Internal API. Implementation note:

     The transformed (X) version of value choice.
...
@@ -408,7 +402,10 @@ class ValueChoiceX(Translatable, nn.Module):
     This class is implemented as a ``nn.Module`` so that it can be scanned by python engine / torchscript.
     """

-    def __init__(self, function: Callable[..., Any], repr_template: str, arguments: List[Any], dry_run: bool = True):
+    def __init__(self, function: Callable[..., _cand] = cast(Callable[..., _cand], None),
+                 repr_template: str = cast(str, None),
+                 arguments: List[Any] = cast('List[MaybeChoice[_cand]]', None),
+                 dry_run: bool = True):
         super().__init__()
         if function is None:
...
@@ -431,7 +428,7 @@ class ValueChoiceX(Translatable, nn.Module):
     def inner_choices(self) -> Iterable['ValueChoice']:
         """
-        Return an iterable of all leaf value choices.
+        Return a generator of all leaf value choices.

         Useful for composition of value choices.
         No deduplication on labels. Mutators should take care.
         """
...
@@ -439,18 +436,18 @@ class ValueChoiceX(Translatable, nn.Module):
             if isinstance(arg, ValueChoiceX):
                 yield from arg.inner_choices()

-    def dry_run(self) -> Any:
+    def dry_run(self) -> _cand:
         """
         Dry run the value choice to get one of its possible evaluation results.
         """
         # values are not used
         return self._evaluate(iter([]), True)

-    def all_options(self) -> Iterable[Any]:
+    def all_options(self) -> Iterable[_cand]:
         """Explore all possibilities of a value choice.
         """
         # Record all inner choices: label -> candidates, no duplicates.
-        dedup_inner_choices: Dict[str, List[Any]] = {}
+        dedup_inner_choices: Dict[str, List[_cand]] = {}
         # All labels of leaf nodes on tree, possibly duplicates.
         all_labels: List[str] = []
...
@@ -470,14 +467,14 @@ class ValueChoiceX(Translatable, nn.Module):
             chosen = dict(zip(dedup_labels, chosen))
             yield self.evaluate([chosen[label] for label in all_labels])

-    def evaluate(self, values: Iterable[Any]) -> Any:
+    def evaluate(self, values: Iterable[_cand]) -> _cand:
         """
         Evaluate the result of this group.
         ``values`` should in the same order of ``inner_choices()``.
         """
         return self._evaluate(iter(values), False)

-    def _evaluate(self, values: Iterable[Any], dry_run: bool = False) -> Any:
+    def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
         # "values" iterates in the recursion
         eval_args = []
         for arg in self.arguments:
...
@@ -497,7 +494,7 @@ class ValueChoiceX(Translatable, nn.Module):
         """
         return self.dry_run()

-    def __repr__(self):
+    def __repr__(self) -> str:
         reprs = []
         for arg in self.arguments:
             if isinstance(arg, ValueChoiceX) and not isinstance(arg, ValueChoice):
...
@@ -513,7 +510,7 @@ class ValueChoiceX(Translatable, nn.Module):
     # Special operators that can be useful in place of built-in conditional operators.

     @staticmethod
     @_valuechoice_staticmethod_helper
-    def to_int(obj: 'ValueChoiceOrAny') -> Union['ValueChoiceX', int]:
+    def to_int(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[int]':
         """
         Convert a ``ValueChoice`` to an integer.
         """
...
@@ -523,7 +520,7 @@ class ValueChoiceX(Translatable, nn.Module):
     @staticmethod
     @_valuechoice_staticmethod_helper
-    def to_float(obj: 'ValueChoiceOrAny') -> Union['ValueChoiceX', float]:
+    def to_float(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[float]':
         """
         Convert a ``ValueChoice`` to a float.
         """
...
@@ -533,9 +530,9 @@ class ValueChoiceX(Translatable, nn.Module):
     @staticmethod
     @_valuechoice_staticmethod_helper
-    def condition(pred: 'ValueChoiceOrAny',
-                  true: 'ValueChoiceOrAny',
-                  false: 'ValueChoiceOrAny') -> 'ValueChoiceOrAny':
+    def condition(pred: 'MaybeChoice[bool]',
+                  true: 'MaybeChoice[_value]',
+                  false: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
         """
         Return ``true`` if the predicate ``pred`` is true else ``false``.
...
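These static helpers compose lazily, which is how the hub models compute derived hyperparameters such as hidden widths (compare `hidden_ch` in proxylessnas.py above). A small sketch under that usage (labels and numbers are illustrative):

    import nni.retiarii.nn.pytorch as nn

    expand_ratio = nn.ValueChoice([2.0, 4.0, 6.0], label='expand_ratio')
    hidden = nn.ValueChoice.to_int(expand_ratio * 16)   # MaybeChoice[float] -> MaybeChoice[int]
    # condition() stands in for `x if pred else y`, which cannot be overloaded:
    clamped = nn.ValueChoice.condition(hidden > 48, 48, hidden)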
@@ -549,35 +546,39 @@ class ValueChoiceX(Translatable, nn.Module):
     @staticmethod
     @_valuechoice_staticmethod_helper
-    def max(arg0: Union[Iterable['ValueChoiceOrAny'], 'ValueChoiceOrAny'],
-            *args: List['ValueChoiceOrAny']) -> 'ValueChoiceOrAny':
+    def max(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
+            *args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
         """
         Returns the maximum value from a list of value choices.
         The usage should be similar to Python's built-in value choices,
         where the parameters could be an iterable, or at least two arguments.
         """
         if not args:
-            return ValueChoiceX.max(*list(arg0))
-        lst = [arg0] + list(args)
+            if not isinstance(arg0, Iterable):
+                raise TypeError('Expect more than one items to compare max')
+            return cast(MaybeChoice[_value], ValueChoiceX.max(*list(arg0)))
+        lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
         if any(isinstance(obj, ValueChoiceX) for obj in lst):
             return ValueChoiceX(max, 'max({})', lst)
-        return max(lst)
+        return max(cast(Any, lst))

     @staticmethod
     @_valuechoice_staticmethod_helper
-    def min(arg0: Union[Iterable['ValueChoiceOrAny'], 'ValueChoiceOrAny'],
-            *args: List['ValueChoiceOrAny']) -> 'ValueChoiceOrAny':
+    def min(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
+            *args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
         """
         Returns the minunum value from a list of value choices.
         The usage should be similar to Python's built-in value choices,
         where the parameters could be an iterable, or at least two arguments.
         """
         if not args:
-            return ValueChoiceX.min(*list(arg0))
-        lst = [arg0] + list(args)
+            if not isinstance(arg0, Iterable):
+                raise TypeError('Expect more than one items to compare min')
+            return cast(MaybeChoice[_value], ValueChoiceX.min(*list(arg0)))
+        lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
         if any(isinstance(obj, ValueChoiceX) for obj in lst):
             return ValueChoiceX(min, 'min({})', lst)
-        return min(lst)
+        return min(cast(Any, lst))

     def __hash__(self):
         # this is required because we have implemented ``__eq__``
...
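With this change, calling `max`/`min` with a single non-iterable argument raises an explicit `TypeError` instead of failing obscurely inside `list(arg0)`. Both calling conventions from the docstring remain available; a brief example (labels are illustrative):

    import nni.retiarii.nn.pytorch as nn

    kernel = nn.ValueChoice([3, 5, 7], label='kernel')
    padding = nn.ValueChoice.max([kernel // 2, 1])   # iterable form
    limited = nn.ValueChoice.min(kernel, 5)          # two-or-more-arguments form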
@@ -589,24 +590,25 @@ class ValueChoiceX(Translatable, nn.Module):
     # - Implementation effort is too huge.
     # As a result, inplace operators like +=, *=, magic methods like `__getattr__` are not included in this list.

-    def __getitem__(self, key: Any) -> 'ValueChoiceX':
+    def __getitem__(self: 'ChoiceOf[Any]', key: Any) -> 'ChoiceOf[Any]':
         return ValueChoiceX(lambda x, y: x[y], '{}[{}]', [self, key])

     # region implement int, float, round, trunc, floor, ceil
     # because I believe sometimes we need them to calculate #channels
     # `__int__` and `__float__` are not supported because `__int__` is required to return int.

-    def __round__(self, ndigits: Optional[Any] = None) -> 'ValueChoiceX':
+    def __round__(self: 'ChoiceOf[SupportsRound[_value]]',
+                  ndigits: Optional['MaybeChoice[int]'] = None) -> 'ChoiceOf[Union[int, SupportsRound[_value]]]':
         if ndigits is not None:
-            return ValueChoiceX(round, 'round({}, {})', [self, ndigits])
-        return ValueChoiceX(round, 'round({})', [self])
+            return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({}, {})', [self, ndigits]))
+        return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({})', [self]))

-    def __trunc__(self) -> 'ValueChoiceX':
+    def __trunc__(self) -> NoReturn:
         raise RuntimeError("Try to use `ValueChoice.to_int()` instead of `math.trunc()` on value choices.")

-    def __floor__(self) -> 'ValueChoiceX':
+    def __floor__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
         return ValueChoiceX(math.floor, 'math.floor({})', [self])

-    def __ceil__(self) -> 'ValueChoiceX':
+    def __ceil__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
         return ValueChoiceX(math.ceil, 'math.ceil({})', [self])

     def __index__(self) -> NoReturn:
...
@@ -622,132 +624,133 @@ class ValueChoiceX(Translatable, nn.Module):
     # region the following code is generated with codegen (see above)
     # Annotated with "region" because I want to collapse them in vscode

-    def __neg__(self) -> 'ValueChoiceX':
-        return ValueChoiceX(operator.neg, '-{}', [self])
+    def __neg__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
+        return cast(ChoiceOf[_value], ValueChoiceX(operator.neg, '-{}', [self]))

-    def __pos__(self) -> 'ValueChoiceX':
-        return ValueChoiceX(operator.pos, '+{}', [self])
+    def __pos__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
+        return cast(ChoiceOf[_value], ValueChoiceX(operator.pos, '+{}', [self]))

-    def __invert__(self) -> 'ValueChoiceX':
-        return ValueChoiceX(operator.invert, '~{}', [self])
+    def __invert__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
+        return cast(ChoiceOf[_value], ValueChoiceX(operator.invert, '~{}', [self]))

-    def __add__(self, other: Any) -> 'ValueChoiceX':
+    def __add__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.add, '{} + {}', [self, other])

-    def __radd__(self, other: Any) -> 'ValueChoiceX':
+    def __radd__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.add, '{} + {}', [other, self])

-    def __sub__(self, other: Any) -> 'ValueChoiceX':
+    def __sub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.sub, '{} - {}', [self, other])

-    def __rsub__(self, other: Any) -> 'ValueChoiceX':
+    def __rsub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.sub, '{} - {}', [other, self])

-    def __mul__(self, other: Any) -> 'ValueChoiceX':
+    def __mul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.mul, '{} * {}', [self, other])

-    def __rmul__(self, other: Any) -> 'ValueChoiceX':
+    def __rmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.mul, '{} * {}', [other, self])

-    def __matmul__(self, other: Any) -> 'ValueChoiceX':
+    def __matmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.matmul, '{} @ {}', [self, other])

-    def __rmatmul__(self, other: Any) -> 'ValueChoiceX':
+    def __rmatmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.matmul, '{} @ {}', [other, self])

-    def __truediv__(self, other: Any) -> 'ValueChoiceX':
+    def __truediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.truediv, '{} // {}', [self, other])

-    def __rtruediv__(self, other: Any) -> 'ValueChoiceX':
+    def __rtruediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.truediv, '{} // {}', [other, self])

-    def __floordiv__(self, other: Any) -> 'ValueChoiceX':
+    def __floordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.floordiv, '{} / {}', [self, other])

-    def __rfloordiv__(self, other: Any) -> 'ValueChoiceX':
+    def __rfloordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.floordiv, '{} / {}', [other, self])

-    def __mod__(self, other: Any) -> 'ValueChoiceX':
+    def __mod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.mod, '{} % {}', [self, other])

-    def __rmod__(self, other: Any) -> 'ValueChoiceX':
+    def __rmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.mod, '{} % {}', [other, self])

-    def __lshift__(self, other: Any) -> 'ValueChoiceX':
+    def __lshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.lshift, '{} << {}', [self, other])

-    def __rlshift__(self, other: Any) -> 'ValueChoiceX':
+    def __rlshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.lshift, '{} << {}', [other, self])

-    def __rshift__(self, other: Any) -> 'ValueChoiceX':
+    def __rshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.rshift, '{} >> {}', [self, other])

-    def __rrshift__(self, other: Any) -> 'ValueChoiceX':
+    def __rrshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.rshift, '{} >> {}', [other, self])

-    def __and__(self, other: Any) -> 'ValueChoiceX':
+    def __and__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.and_, '{} & {}', [self, other])

-    def __rand__(self, other: Any) -> 'ValueChoiceX':
+    def __rand__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.and_, '{} & {}', [other, self])

-    def __xor__(self, other: Any) -> 'ValueChoiceX':
+    def __xor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.xor, '{} ^ {}', [self, other])

-    def __rxor__(self, other: Any) -> 'ValueChoiceX':
+    def __rxor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.xor, '{} ^ {}', [other, self])

-    def __or__(self, other: Any) -> 'ValueChoiceX':
+    def __or__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.or_, '{} | {}', [self, other])

-    def __ror__(self, other: Any) -> 'ValueChoiceX':
+    def __ror__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.or_, '{} | {}', [other, self])

-    def __lt__(self, other: Any) -> 'ValueChoiceX':
+    def __lt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.lt, '{} < {}', [self, other])

-    def __le__(self, other: Any) -> 'ValueChoiceX':
+    def __le__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.le, '{} <= {}', [self, other])

-    def __eq__(self, other: Any) -> 'ValueChoiceX':
+    def __eq__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.eq, '{} == {}', [self, other])

-    def __ne__(self, other: Any) -> 'ValueChoiceX':
+    def __ne__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.ne, '{} != {}', [self, other])

-    def __ge__(self, other: Any) -> 'ValueChoiceX':
+    def __ge__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.ge, '{} >= {}', [self, other])

-    def __gt__(self, other: Any) -> 'ValueChoiceX':
+    def __gt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(operator.gt, '{} > {}', [self, other])

     # endregion

     # __pow__, __divmod__, __abs__ are special ones.
     # Not easy to cover those cases with codegen.

-    def __pow__(self, other: Any, modulo: Optional[Any] = None) -> 'ValueChoiceX':
+    def __pow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
         if modulo is not None:
             return ValueChoiceX(pow, 'pow({}, {}, {})', [self, other, modulo])
         return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [self, other])

-    def __rpow__(self, other: Any, modulo: Optional[Any] = None) -> 'ValueChoiceX':
+    def __rpow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
         if modulo is not None:
             return ValueChoiceX(pow, 'pow({}, {}, {})', [other, self, modulo])
         return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [other, self])

-    def __divmod__(self, other: Any) -> 'ValueChoiceX':
+    def __divmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(divmod, 'divmod({}, {})', [self, other])

-    def __rdivmod__(self, other: Any) -> 'ValueChoiceX':
+    def __rdivmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(divmod, 'divmod({}, {})', [other, self])

-    def __abs__(self) -> 'ValueChoiceX':
+    def __abs__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[Any]':
         return ValueChoiceX(abs, 'abs({})', [self])
-ValueChoiceOrAny = TypeVar('ValueChoiceOrAny', ValueChoiceX, Any)
+ChoiceOf = ValueChoiceX
+MaybeChoice = Union[ValueChoiceX[_cand], _cand]
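Reading the two aliases: `ChoiceOf[int]` is an unresolved expression whose candidates are ints, while `MaybeChoice[int]` additionally admits a plain `int`, which is why it is the annotation used for hyperparameters throughout the hub models. A one-line sketch to make the distinction concrete (the function is illustrative, not from the commit):

    def scaled_width(base: int, multiplier: 'MaybeChoice[float]') -> 'MaybeChoice[float]':
        # A plain float yields a float; a ChoiceOf[float] yields a new lazy expression.
        return base * multiplier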
-class ValueChoice(ValueChoiceX, Mutable):
+class ValueChoice(ValueChoiceX[_cand], Mutable):
     """
     ValueChoice is to choose one from ``candidates``. The most common use cases are:
...
@@ -865,14 +868,14 @@ class ValueChoice(ValueChoiceX, Mutable):
     # FIXME: prior is designed but not supported yet

     @classmethod
-    def create_fixed_module(cls, candidates: List[Any], *, label: Optional[str] = None, **kwargs):
+    def create_fixed_module(cls, candidates: List[_cand], *, label: Optional[str] = None, **kwargs):
         value = get_fixed_value(label)
         if value not in candidates:
             raise ValueError(f'Value {value} does not belong to the candidates: {candidates}.')
         return value

-    def __init__(self, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
-        super().__init__(None, None, None)
+    def __init__(self, candidates: List[_cand], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
+        super().__init__()
         self.candidates = candidates
         self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
         assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
...
@@ -894,10 +897,10 @@ class ValueChoice(ValueChoiceX, Mutable):
         # yield self because self is the only value choice here
         yield self

-    def dry_run(self) -> Any:
+    def dry_run(self) -> _cand:
         return self.candidates[0]

-    def _evaluate(self, values: Iterable[Any], dry_run: bool = False) -> Any:
+    def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
         if dry_run:
             return self.candidates[0]
         try:
...
@@ -986,6 +989,7 @@ class ModelParameterChoice:
     Examples
     --------
     Get a dynamic-shaped parameter. Because ``torch.zeros`` is not a basic unit, we can't use :class:`ValueChoice` on it.
+
     >>> parameter_dim = nn.ModelParameterChoice([64, 128, 256])
     >>> self.token = nn.Parameter(torch.zeros(1, parameter_dim, 32, 32))
     """
...
@@ -1016,12 +1020,14 @@ class ModelParameterChoice:
         if default not in candidates:
             # could be callable
             try:
-                default = default(candidates)
+                default = cast(Callable[[List[ValueType]], ValueType], default)(candidates)
             except TypeError as e:
                 if 'not callable' in str(e):
                     raise TypeError("`default` is not in `candidates`, and it's also not callable.")
                 raise
+        default = cast(ValueType, default)

         label = generate_new_label(label)
         parameter_spec = ParameterSpec(
             label,  # name
...
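Putting the two public APIs of this file side by side: `ValueChoice` mutates arguments of basic units, while `ModelParameterChoice` covers shapes that `ValueChoice` cannot reach, exactly as the docstring above shows. A combined sketch (labels and dimensions are illustrative, not from the commit):

    import torch

    import nni.retiarii.nn.pytorch as nn

    class Head(nn.Module):
        def __init__(self):
            super().__init__()
            # ModelParameterChoice: torch.zeros is not a basic unit, so ValueChoice
            # would not work here (see the Examples section above).
            token_dim = nn.ModelParameterChoice([64, 128, 256], label='token_dim')
            self.token = nn.Parameter(torch.zeros(1, token_dim))
            # ValueChoice: nn.Linear is a basic unit, so its arguments can be searched.
            self.fc = nn.Linear(nn.ValueChoice([64, 128, 256], label='width'), 10)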