Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
cae4308f
Unverified
Commit
cae4308f
authored
Mar 03, 2021
by
Yuge Zhang
Committed by
GitHub
Mar 03, 2021
Browse files
[Retiarii] Rename APIs and refine documentation (#3404)
parent
d047d6f4
Changes
59
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
310 additions
and
327 deletions
+310
-327
nni/retiarii/graph.py
nni/retiarii/graph.py
+20
-20
nni/retiarii/integration.py
nni/retiarii/integration.py
+1
-1
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+1
-1
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+5
-7
nni/retiarii/nn/pytorch/nn.py
nni/retiarii/nn/pytorch/nn.py
+115
-129
nni/retiarii/oneshot/__init__.py
nni/retiarii/oneshot/__init__.py
+0
-1
nni/retiarii/oneshot/interface.py
nni/retiarii/oneshot/interface.py
+0
-5
nni/retiarii/oneshot/pytorch/__init__.py
nni/retiarii/oneshot/pytorch/__init__.py
+5
-0
nni/retiarii/oneshot/pytorch/darts.py
nni/retiarii/oneshot/pytorch/darts.py
+0
-0
nni/retiarii/oneshot/pytorch/enas.py
nni/retiarii/oneshot/pytorch/enas.py
+0
-0
nni/retiarii/oneshot/pytorch/proxyless.py
nni/retiarii/oneshot/pytorch/proxyless.py
+0
-0
nni/retiarii/oneshot/pytorch/random.py
nni/retiarii/oneshot/pytorch/random.py
+0
-0
nni/retiarii/oneshot/pytorch/utils.py
nni/retiarii/oneshot/pytorch/utils.py
+0
-0
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+143
-0
nni/retiarii/utils.py
nni/retiarii/utils.py
+3
-146
test/retiarii_test/darts/darts_model.py
test/retiarii_test/darts/darts_model.py
+2
-2
test/retiarii_test/darts/ops.py
test/retiarii_test/darts/ops.py
+8
-8
test/retiarii_test/darts/test.py
test/retiarii_test/darts/test.py
+4
-4
test/retiarii_test/darts/test_oneshot.py
test/retiarii_test/darts/test_oneshot.py
+1
-1
test/retiarii_test/mnasnet/base_mnasnet.py
test/retiarii_test/mnasnet/base_mnasnet.py
+2
-2
No files found.
nni/retiarii/graph.py
View file @
cae4308f
...
@@ -25,14 +25,14 @@ Type hint for edge's endpoint. The int indicates nodes' order.
...
@@ -25,14 +25,14 @@ Type hint for edge's endpoint. The int indicates nodes' order.
"""
"""
class
TrainingConfig
(
abc
.
ABC
):
class
Evaluator
(
abc
.
ABC
):
"""
"""
Training config of a model. A training config
should define where the training code is, and the configuration of
Evaluator of a model. An evaluator
should define where the training code is, and the configuration of
training code. The configuration includes basic runtime information trainer needs to know (such as number of GPUs)
training code. The configuration includes basic runtime information trainer needs to know (such as number of GPUs)
or tune-able parameters (such as learning rate), depending on the implementation of training code.
or tune-able parameters (such as learning rate), depending on the implementation of training code.
Each config should define how it is interpreted in ``_execute()``, taking only one argument which is the mutated model class.
Each config should define how it is interpreted in ``_execute()``, taking only one argument which is the mutated model class.
For example, functional
training config
might directly import the function and call the function.
For example, functional
evaluator
might directly import the function and call the function.
"""
"""
def
__repr__
(
self
):
def
__repr__
(
self
):
...
@@ -40,15 +40,15 @@ class TrainingConfig(abc.ABC):
...
@@ -40,15 +40,15 @@ class TrainingConfig(abc.ABC):
return
f
'
{
self
.
__class__
.
__name__
}
(
{
items
}
)'
return
f
'
{
self
.
__class__
.
__name__
}
(
{
items
}
)'
@
abc
.
abstractstaticmethod
@
abc
.
abstractstaticmethod
def
_load
(
ir
:
Any
)
->
'
TrainingConfig
'
:
def
_load
(
ir
:
Any
)
->
'
Evaluator
'
:
pass
pass
@
staticmethod
@
staticmethod
def
_load_with_type
(
type_name
:
str
,
ir
:
Any
)
->
'Optional[
TrainingConfig
]'
:
def
_load_with_type
(
type_name
:
str
,
ir
:
Any
)
->
'Optional[
Evaluator
]'
:
if
type_name
==
'_debug_no_trainer'
:
if
type_name
==
'_debug_no_trainer'
:
return
Debug
Training
()
return
Debug
Evaluator
()
config_cls
=
import_
(
type_name
)
config_cls
=
import_
(
type_name
)
assert
issubclass
(
config_cls
,
TrainingConfig
)
assert
issubclass
(
config_cls
,
Evaluator
)
return
config_cls
.
_load
(
ir
)
return
config_cls
.
_load
(
ir
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
...
@@ -83,8 +83,8 @@ class Model:
...
@@ -83,8 +83,8 @@ class Model:
The outermost graph which usually takes dataset as input and feeds output to loss function.
The outermost graph which usually takes dataset as input and feeds output to loss function.
graphs
graphs
All graphs (subgraphs) in this model.
All graphs (subgraphs) in this model.
training_config
evaluator
Training config
Model evaluator
history
history
Mutation history.
Mutation history.
`self` is directly mutated from `self.history[-1]`;
`self` is directly mutated from `self.history[-1]`;
...
@@ -104,7 +104,7 @@ class Model:
...
@@ -104,7 +104,7 @@ class Model:
self
.
_root_graph_name
:
str
=
'_model'
self
.
_root_graph_name
:
str
=
'_model'
self
.
graphs
:
Dict
[
str
,
Graph
]
=
{}
self
.
graphs
:
Dict
[
str
,
Graph
]
=
{}
self
.
training_config
:
Optional
[
TrainingConfig
]
=
None
self
.
evaluator
:
Optional
[
Evaluator
]
=
None
self
.
history
:
List
[
Model
]
=
[]
self
.
history
:
List
[
Model
]
=
[]
...
@@ -113,7 +113,7 @@ class Model:
...
@@ -113,7 +113,7 @@ class Model:
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
'Model(model_id=
{
self
.
model_id
}
, status=
{
self
.
status
}
, graphs=
{
list
(
self
.
graphs
.
keys
())
}
, '
+
\
return
f
'Model(model_id=
{
self
.
model_id
}
, status=
{
self
.
status
}
, graphs=
{
list
(
self
.
graphs
.
keys
())
}
, '
+
\
f
'
training_config=
{
self
.
training_config
}
, metric=
{
self
.
metric
}
, intermediate_metrics=
{
self
.
intermediate_metrics
}
)'
f
'
evaluator=
{
self
.
evaluator
}
, metric=
{
self
.
metric
}
, intermediate_metrics=
{
self
.
intermediate_metrics
}
)'
@
property
@
property
def
root_graph
(
self
)
->
'Graph'
:
def
root_graph
(
self
)
->
'Graph'
:
...
@@ -131,7 +131,7 @@ class Model:
...
@@ -131,7 +131,7 @@ class Model:
new_model
=
Model
(
_internal
=
True
)
new_model
=
Model
(
_internal
=
True
)
new_model
.
_root_graph_name
=
self
.
_root_graph_name
new_model
.
_root_graph_name
=
self
.
_root_graph_name
new_model
.
graphs
=
{
name
:
graph
.
_fork_to
(
new_model
)
for
name
,
graph
in
self
.
graphs
.
items
()}
new_model
.
graphs
=
{
name
:
graph
.
_fork_to
(
new_model
)
for
name
,
graph
in
self
.
graphs
.
items
()}
new_model
.
training_config
=
copy
.
deepcopy
(
self
.
training_config
)
# TODO this may be a problem when
training config
is large
new_model
.
evaluator
=
copy
.
deepcopy
(
self
.
evaluator
)
# TODO this may be a problem when
evaluator
is large
new_model
.
history
=
self
.
history
+
[
self
]
new_model
.
history
=
self
.
history
+
[
self
]
return
new_model
return
new_model
...
@@ -139,16 +139,16 @@ class Model:
...
@@ -139,16 +139,16 @@ class Model:
def
_load
(
ir
:
Any
)
->
'Model'
:
def
_load
(
ir
:
Any
)
->
'Model'
:
model
=
Model
(
_internal
=
True
)
model
=
Model
(
_internal
=
True
)
for
graph_name
,
graph_data
in
ir
.
items
():
for
graph_name
,
graph_data
in
ir
.
items
():
if
graph_name
!=
'_
training_config
'
:
if
graph_name
!=
'_
evaluator
'
:
Graph
.
_load
(
model
,
graph_name
,
graph_data
).
_register
()
Graph
.
_load
(
model
,
graph_name
,
graph_data
).
_register
()
model
.
training_config
=
TrainingConfig
.
_load_with_type
(
ir
[
'_
training_config
'
][
'__type__'
],
ir
[
'_
training_config
'
])
model
.
evaluator
=
Evaluator
.
_load_with_type
(
ir
[
'_
evaluator
'
][
'__type__'
],
ir
[
'_
evaluator
'
])
return
model
return
model
def
_dump
(
self
)
->
Any
:
def
_dump
(
self
)
->
Any
:
ret
=
{
name
:
graph
.
_dump
()
for
name
,
graph
in
self
.
graphs
.
items
()}
ret
=
{
name
:
graph
.
_dump
()
for
name
,
graph
in
self
.
graphs
.
items
()}
ret
[
'_
training_config
'
]
=
{
ret
[
'_
evaluator
'
]
=
{
'__type__'
:
get_full_class_name
(
self
.
training_config
.
__class__
),
'__type__'
:
get_full_class_name
(
self
.
evaluator
.
__class__
),
**
self
.
training_config
.
_dump
()
**
self
.
evaluator
.
_dump
()
}
}
return
ret
return
ret
...
@@ -681,10 +681,10 @@ class IllegalGraphError(ValueError):
...
@@ -681,10 +681,10 @@ class IllegalGraphError(ValueError):
json
.
dump
(
graph
,
dump_file
,
indent
=
4
)
json
.
dump
(
graph
,
dump_file
,
indent
=
4
)
class
Debug
Training
(
TrainingConfig
):
class
Debug
Evaluator
(
Evaluator
):
@
staticmethod
@
staticmethod
def
_load
(
ir
:
Any
)
->
'Debug
Training
'
:
def
_load
(
ir
:
Any
)
->
'Debug
Evaluator
'
:
return
Debug
Training
()
return
Debug
Evaluator
()
def
_dump
(
self
)
->
Any
:
def
_dump
(
self
)
->
Any
:
return
{
'__type__'
:
'_debug_no_trainer'
}
return
{
'__type__'
:
'_debug_no_trainer'
}
...
...
nni/retiarii/integration.py
View file @
cae4308f
...
@@ -11,7 +11,7 @@ from .execution.base import BaseExecutionEngine
...
@@ -11,7 +11,7 @@ from .execution.base import BaseExecutionEngine
from
.execution.cgo_engine
import
CGOExecutionEngine
from
.execution.cgo_engine
import
CGOExecutionEngine
from
.execution.api
import
set_execution_engine
from
.execution.api
import
set_execution_engine
from
.integration_api
import
register_advisor
from
.integration_api
import
register_advisor
from
.
utils
import
json_dumps
,
json_loads
from
.
serializer
import
json_dumps
,
json_loads
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
...
nni/retiarii/integration_api.py
View file @
cae4308f
...
@@ -3,7 +3,7 @@ from typing import NewType, Any
...
@@ -3,7 +3,7 @@ from typing import NewType, Any
import
nni
import
nni
from
.
utils
import
json_loads
from
.
serializer
import
json_loads
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
# because it would induce cycled import
...
...
nni/retiarii/nn/pytorch/api.py
View file @
cae4308f
...
@@ -5,7 +5,8 @@ import warnings
...
@@ -5,7 +5,8 @@ import warnings
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
...utils
import
uid
,
add_record
,
del_record
,
Translatable
from
...serializer
import
Translatable
,
basic_unit
from
...utils
import
uid
__all__
=
[
'LayerChoice'
,
'InputChoice'
,
'ValueChoice'
,
'Placeholder'
,
'ChosenInputs'
]
__all__
=
[
'LayerChoice'
,
'InputChoice'
,
'ValueChoice'
,
'Placeholder'
,
'ChosenInputs'
]
...
@@ -281,21 +282,18 @@ class ValueChoice(Translatable, nn.Module):
...
@@ -281,21 +282,18 @@ class ValueChoice(Translatable, nn.Module):
return
f
'ValueChoice(
{
self
.
candidates
}
, label=
{
repr
(
self
.
label
)
}
)'
return
f
'ValueChoice(
{
self
.
candidates
}
, label=
{
repr
(
self
.
label
)
}
)'
@
basic_unit
class
Placeholder
(
nn
.
Module
):
class
Placeholder
(
nn
.
Module
):
# TODO: docstring
# TODO: docstring
def
__init__
(
self
,
label
,
related_info
):
def
__init__
(
self
,
label
,
**
related_info
):
add_record
(
id
(
self
),
related_info
)
self
.
label
=
label
self
.
label
=
label
self
.
related_info
=
related_info
self
.
related_info
=
related_info
super
(
Placeholder
,
self
).
__init__
()
super
().
__init__
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
x
return
x
def
__del__
(
self
):
del_record
(
id
(
self
))
class
ChosenInputs
(
nn
.
Module
):
class
ChosenInputs
(
nn
.
Module
):
"""
"""
...
...
nni/retiarii/nn/pytorch/nn.py
View file @
cae4308f
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
...utils
import
add_record
,
blackbox_module
,
del_record
,
version_larger_equal
from
...serializer
import
basic_unit
from
...serializer
import
transparent_serialize
from
...utils
import
version_larger_equal
# NOTE: support pytorch version >= 1.5.0
# NOTE: support pytorch version >= 1.5.0
...
@@ -36,135 +38,119 @@ if version_larger_equal(torch.__version__, '1.7.0'):
...
@@ -36,135 +38,119 @@ if version_larger_equal(torch.__version__, '1.7.0'):
Module
=
nn
.
Module
Module
=
nn
.
Module
Sequential
=
transparent_serialize
(
nn
.
Sequential
)
class
Sequential
(
nn
.
Sequential
):
ModuleList
=
transparent_serialize
(
nn
.
ModuleList
)
def
__init__
(
self
,
*
args
):
add_record
(
id
(
self
),
{})
Identity
=
basic_unit
(
nn
.
Identity
)
super
(
Sequential
,
self
).
__init__
(
*
args
)
Linear
=
basic_unit
(
nn
.
Linear
)
Conv1d
=
basic_unit
(
nn
.
Conv1d
)
def
__del__
(
self
):
Conv2d
=
basic_unit
(
nn
.
Conv2d
)
del_record
(
id
(
self
))
Conv3d
=
basic_unit
(
nn
.
Conv3d
)
ConvTranspose1d
=
basic_unit
(
nn
.
ConvTranspose1d
)
ConvTranspose2d
=
basic_unit
(
nn
.
ConvTranspose2d
)
class
ModuleList
(
nn
.
ModuleList
):
ConvTranspose3d
=
basic_unit
(
nn
.
ConvTranspose3d
)
def
__init__
(
self
,
*
args
):
Threshold
=
basic_unit
(
nn
.
Threshold
)
add_record
(
id
(
self
),
{})
ReLU
=
basic_unit
(
nn
.
ReLU
)
super
(
ModuleList
,
self
).
__init__
(
*
args
)
Hardtanh
=
basic_unit
(
nn
.
Hardtanh
)
ReLU6
=
basic_unit
(
nn
.
ReLU6
)
def
__del__
(
self
):
Sigmoid
=
basic_unit
(
nn
.
Sigmoid
)
del_record
(
id
(
self
))
Tanh
=
basic_unit
(
nn
.
Tanh
)
Softmax
=
basic_unit
(
nn
.
Softmax
)
Softmax2d
=
basic_unit
(
nn
.
Softmax2d
)
Identity
=
blackbox_module
(
nn
.
Identity
)
LogSoftmax
=
basic_unit
(
nn
.
LogSoftmax
)
Linear
=
blackbox_module
(
nn
.
Linear
)
ELU
=
basic_unit
(
nn
.
ELU
)
Conv1d
=
blackbox_module
(
nn
.
Conv1d
)
SELU
=
basic_unit
(
nn
.
SELU
)
Conv2d
=
blackbox_module
(
nn
.
Conv2d
)
CELU
=
basic_unit
(
nn
.
CELU
)
Conv3d
=
blackbox_module
(
nn
.
Conv3d
)
GLU
=
basic_unit
(
nn
.
GLU
)
ConvTranspose1d
=
blackbox_module
(
nn
.
ConvTranspose1d
)
GELU
=
basic_unit
(
nn
.
GELU
)
ConvTranspose2d
=
blackbox_module
(
nn
.
ConvTranspose2d
)
Hardshrink
=
basic_unit
(
nn
.
Hardshrink
)
ConvTranspose3d
=
blackbox_module
(
nn
.
ConvTranspose3d
)
LeakyReLU
=
basic_unit
(
nn
.
LeakyReLU
)
Threshold
=
blackbox_module
(
nn
.
Threshold
)
LogSigmoid
=
basic_unit
(
nn
.
LogSigmoid
)
ReLU
=
blackbox_module
(
nn
.
ReLU
)
Softplus
=
basic_unit
(
nn
.
Softplus
)
Hardtanh
=
blackbox_module
(
nn
.
Hardtanh
)
Softshrink
=
basic_unit
(
nn
.
Softshrink
)
ReLU6
=
blackbox_module
(
nn
.
ReLU6
)
MultiheadAttention
=
basic_unit
(
nn
.
MultiheadAttention
)
Sigmoid
=
blackbox_module
(
nn
.
Sigmoid
)
PReLU
=
basic_unit
(
nn
.
PReLU
)
Tanh
=
blackbox_module
(
nn
.
Tanh
)
Softsign
=
basic_unit
(
nn
.
Softsign
)
Softmax
=
blackbox_module
(
nn
.
Softmax
)
Softmin
=
basic_unit
(
nn
.
Softmin
)
Softmax2d
=
blackbox_module
(
nn
.
Softmax2d
)
Tanhshrink
=
basic_unit
(
nn
.
Tanhshrink
)
LogSoftmax
=
blackbox_module
(
nn
.
LogSoftmax
)
RReLU
=
basic_unit
(
nn
.
RReLU
)
ELU
=
blackbox_module
(
nn
.
ELU
)
AvgPool1d
=
basic_unit
(
nn
.
AvgPool1d
)
SELU
=
blackbox_module
(
nn
.
SELU
)
AvgPool2d
=
basic_unit
(
nn
.
AvgPool2d
)
CELU
=
blackbox_module
(
nn
.
CELU
)
AvgPool3d
=
basic_unit
(
nn
.
AvgPool3d
)
GLU
=
blackbox_module
(
nn
.
GLU
)
MaxPool1d
=
basic_unit
(
nn
.
MaxPool1d
)
GELU
=
blackbox_module
(
nn
.
GELU
)
MaxPool2d
=
basic_unit
(
nn
.
MaxPool2d
)
Hardshrink
=
blackbox_module
(
nn
.
Hardshrink
)
MaxPool3d
=
basic_unit
(
nn
.
MaxPool3d
)
LeakyReLU
=
blackbox_module
(
nn
.
LeakyReLU
)
MaxUnpool1d
=
basic_unit
(
nn
.
MaxUnpool1d
)
LogSigmoid
=
blackbox_module
(
nn
.
LogSigmoid
)
MaxUnpool2d
=
basic_unit
(
nn
.
MaxUnpool2d
)
Softplus
=
blackbox_module
(
nn
.
Softplus
)
MaxUnpool3d
=
basic_unit
(
nn
.
MaxUnpool3d
)
Softshrink
=
blackbox_module
(
nn
.
Softshrink
)
FractionalMaxPool2d
=
basic_unit
(
nn
.
FractionalMaxPool2d
)
MultiheadAttention
=
blackbox_module
(
nn
.
MultiheadAttention
)
FractionalMaxPool3d
=
basic_unit
(
nn
.
FractionalMaxPool3d
)
PReLU
=
blackbox_module
(
nn
.
PReLU
)
LPPool1d
=
basic_unit
(
nn
.
LPPool1d
)
Softsign
=
blackbox_module
(
nn
.
Softsign
)
LPPool2d
=
basic_unit
(
nn
.
LPPool2d
)
Softmin
=
blackbox_module
(
nn
.
Softmin
)
LocalResponseNorm
=
basic_unit
(
nn
.
LocalResponseNorm
)
Tanhshrink
=
blackbox_module
(
nn
.
Tanhshrink
)
BatchNorm1d
=
basic_unit
(
nn
.
BatchNorm1d
)
RReLU
=
blackbox_module
(
nn
.
RReLU
)
BatchNorm2d
=
basic_unit
(
nn
.
BatchNorm2d
)
AvgPool1d
=
blackbox_module
(
nn
.
AvgPool1d
)
BatchNorm3d
=
basic_unit
(
nn
.
BatchNorm3d
)
AvgPool2d
=
blackbox_module
(
nn
.
AvgPool2d
)
InstanceNorm1d
=
basic_unit
(
nn
.
InstanceNorm1d
)
AvgPool3d
=
blackbox_module
(
nn
.
AvgPool3d
)
InstanceNorm2d
=
basic_unit
(
nn
.
InstanceNorm2d
)
MaxPool1d
=
blackbox_module
(
nn
.
MaxPool1d
)
InstanceNorm3d
=
basic_unit
(
nn
.
InstanceNorm3d
)
MaxPool2d
=
blackbox_module
(
nn
.
MaxPool2d
)
LayerNorm
=
basic_unit
(
nn
.
LayerNorm
)
MaxPool3d
=
blackbox_module
(
nn
.
MaxPool3d
)
GroupNorm
=
basic_unit
(
nn
.
GroupNorm
)
MaxUnpool1d
=
blackbox_module
(
nn
.
MaxUnpool1d
)
SyncBatchNorm
=
basic_unit
(
nn
.
SyncBatchNorm
)
MaxUnpool2d
=
blackbox_module
(
nn
.
MaxUnpool2d
)
Dropout
=
basic_unit
(
nn
.
Dropout
)
MaxUnpool3d
=
blackbox_module
(
nn
.
MaxUnpool3d
)
Dropout2d
=
basic_unit
(
nn
.
Dropout2d
)
FractionalMaxPool2d
=
blackbox_module
(
nn
.
FractionalMaxPool2d
)
Dropout3d
=
basic_unit
(
nn
.
Dropout3d
)
FractionalMaxPool3d
=
blackbox_module
(
nn
.
FractionalMaxPool3d
)
AlphaDropout
=
basic_unit
(
nn
.
AlphaDropout
)
LPPool1d
=
blackbox_module
(
nn
.
LPPool1d
)
FeatureAlphaDropout
=
basic_unit
(
nn
.
FeatureAlphaDropout
)
LPPool2d
=
blackbox_module
(
nn
.
LPPool2d
)
ReflectionPad1d
=
basic_unit
(
nn
.
ReflectionPad1d
)
LocalResponseNorm
=
blackbox_module
(
nn
.
LocalResponseNorm
)
ReflectionPad2d
=
basic_unit
(
nn
.
ReflectionPad2d
)
BatchNorm1d
=
blackbox_module
(
nn
.
BatchNorm1d
)
ReplicationPad2d
=
basic_unit
(
nn
.
ReplicationPad2d
)
BatchNorm2d
=
blackbox_module
(
nn
.
BatchNorm2d
)
ReplicationPad1d
=
basic_unit
(
nn
.
ReplicationPad1d
)
BatchNorm3d
=
blackbox_module
(
nn
.
BatchNorm3d
)
ReplicationPad3d
=
basic_unit
(
nn
.
ReplicationPad3d
)
InstanceNorm1d
=
blackbox_module
(
nn
.
InstanceNorm1d
)
CrossMapLRN2d
=
basic_unit
(
nn
.
CrossMapLRN2d
)
InstanceNorm2d
=
blackbox_module
(
nn
.
InstanceNorm2d
)
Embedding
=
basic_unit
(
nn
.
Embedding
)
InstanceNorm3d
=
blackbox_module
(
nn
.
InstanceNorm3d
)
EmbeddingBag
=
basic_unit
(
nn
.
EmbeddingBag
)
LayerNorm
=
blackbox_module
(
nn
.
LayerNorm
)
RNNBase
=
basic_unit
(
nn
.
RNNBase
)
GroupNorm
=
blackbox_module
(
nn
.
GroupNorm
)
RNN
=
basic_unit
(
nn
.
RNN
)
SyncBatchNorm
=
blackbox_module
(
nn
.
SyncBatchNorm
)
LSTM
=
basic_unit
(
nn
.
LSTM
)
Dropout
=
blackbox_module
(
nn
.
Dropout
)
GRU
=
basic_unit
(
nn
.
GRU
)
Dropout2d
=
blackbox_module
(
nn
.
Dropout2d
)
RNNCellBase
=
basic_unit
(
nn
.
RNNCellBase
)
Dropout3d
=
blackbox_module
(
nn
.
Dropout3d
)
RNNCell
=
basic_unit
(
nn
.
RNNCell
)
AlphaDropout
=
blackbox_module
(
nn
.
AlphaDropout
)
LSTMCell
=
basic_unit
(
nn
.
LSTMCell
)
FeatureAlphaDropout
=
blackbox_module
(
nn
.
FeatureAlphaDropout
)
GRUCell
=
basic_unit
(
nn
.
GRUCell
)
ReflectionPad1d
=
blackbox_module
(
nn
.
ReflectionPad1d
)
PixelShuffle
=
basic_unit
(
nn
.
PixelShuffle
)
ReflectionPad2d
=
blackbox_module
(
nn
.
ReflectionPad2d
)
Upsample
=
basic_unit
(
nn
.
Upsample
)
ReplicationPad2d
=
blackbox_module
(
nn
.
ReplicationPad2d
)
UpsamplingNearest2d
=
basic_unit
(
nn
.
UpsamplingNearest2d
)
ReplicationPad1d
=
blackbox_module
(
nn
.
ReplicationPad1d
)
UpsamplingBilinear2d
=
basic_unit
(
nn
.
UpsamplingBilinear2d
)
ReplicationPad3d
=
blackbox_module
(
nn
.
ReplicationPad3d
)
PairwiseDistance
=
basic_unit
(
nn
.
PairwiseDistance
)
CrossMapLRN2d
=
blackbox_module
(
nn
.
CrossMapLRN2d
)
AdaptiveMaxPool1d
=
basic_unit
(
nn
.
AdaptiveMaxPool1d
)
Embedding
=
blackbox_module
(
nn
.
Embedding
)
AdaptiveMaxPool2d
=
basic_unit
(
nn
.
AdaptiveMaxPool2d
)
EmbeddingBag
=
blackbox_module
(
nn
.
EmbeddingBag
)
AdaptiveMaxPool3d
=
basic_unit
(
nn
.
AdaptiveMaxPool3d
)
RNNBase
=
blackbox_module
(
nn
.
RNNBase
)
AdaptiveAvgPool1d
=
basic_unit
(
nn
.
AdaptiveAvgPool1d
)
RNN
=
blackbox_module
(
nn
.
RNN
)
AdaptiveAvgPool2d
=
basic_unit
(
nn
.
AdaptiveAvgPool2d
)
LSTM
=
blackbox_module
(
nn
.
LSTM
)
AdaptiveAvgPool3d
=
basic_unit
(
nn
.
AdaptiveAvgPool3d
)
GRU
=
blackbox_module
(
nn
.
GRU
)
TripletMarginLoss
=
basic_unit
(
nn
.
TripletMarginLoss
)
RNNCellBase
=
blackbox_module
(
nn
.
RNNCellBase
)
ZeroPad2d
=
basic_unit
(
nn
.
ZeroPad2d
)
RNNCell
=
blackbox_module
(
nn
.
RNNCell
)
ConstantPad1d
=
basic_unit
(
nn
.
ConstantPad1d
)
LSTMCell
=
blackbox_module
(
nn
.
LSTMCell
)
ConstantPad2d
=
basic_unit
(
nn
.
ConstantPad2d
)
GRUCell
=
blackbox_module
(
nn
.
GRUCell
)
ConstantPad3d
=
basic_unit
(
nn
.
ConstantPad3d
)
PixelShuffle
=
blackbox_module
(
nn
.
PixelShuffle
)
Bilinear
=
basic_unit
(
nn
.
Bilinear
)
Upsample
=
blackbox_module
(
nn
.
Upsample
)
CosineSimilarity
=
basic_unit
(
nn
.
CosineSimilarity
)
UpsamplingNearest2d
=
blackbox_module
(
nn
.
UpsamplingNearest2d
)
Unfold
=
basic_unit
(
nn
.
Unfold
)
UpsamplingBilinear2d
=
blackbox_module
(
nn
.
UpsamplingBilinear2d
)
Fold
=
basic_unit
(
nn
.
Fold
)
PairwiseDistance
=
blackbox_module
(
nn
.
PairwiseDistance
)
AdaptiveLogSoftmaxWithLoss
=
basic_unit
(
nn
.
AdaptiveLogSoftmaxWithLoss
)
AdaptiveMaxPool1d
=
blackbox_module
(
nn
.
AdaptiveMaxPool1d
)
TransformerEncoder
=
basic_unit
(
nn
.
TransformerEncoder
)
AdaptiveMaxPool2d
=
blackbox_module
(
nn
.
AdaptiveMaxPool2d
)
TransformerDecoder
=
basic_unit
(
nn
.
TransformerDecoder
)
AdaptiveMaxPool3d
=
blackbox_module
(
nn
.
AdaptiveMaxPool3d
)
TransformerEncoderLayer
=
basic_unit
(
nn
.
TransformerEncoderLayer
)
AdaptiveAvgPool1d
=
blackbox_module
(
nn
.
AdaptiveAvgPool1d
)
TransformerDecoderLayer
=
basic_unit
(
nn
.
TransformerDecoderLayer
)
AdaptiveAvgPool2d
=
blackbox_module
(
nn
.
AdaptiveAvgPool2d
)
Transformer
=
basic_unit
(
nn
.
Transformer
)
AdaptiveAvgPool3d
=
blackbox_module
(
nn
.
AdaptiveAvgPool3d
)
Flatten
=
basic_unit
(
nn
.
Flatten
)
TripletMarginLoss
=
blackbox_module
(
nn
.
TripletMarginLoss
)
Hardsigmoid
=
basic_unit
(
nn
.
Hardsigmoid
)
ZeroPad2d
=
blackbox_module
(
nn
.
ZeroPad2d
)
ConstantPad1d
=
blackbox_module
(
nn
.
ConstantPad1d
)
ConstantPad2d
=
blackbox_module
(
nn
.
ConstantPad2d
)
ConstantPad3d
=
blackbox_module
(
nn
.
ConstantPad3d
)
Bilinear
=
blackbox_module
(
nn
.
Bilinear
)
CosineSimilarity
=
blackbox_module
(
nn
.
CosineSimilarity
)
Unfold
=
blackbox_module
(
nn
.
Unfold
)
Fold
=
blackbox_module
(
nn
.
Fold
)
AdaptiveLogSoftmaxWithLoss
=
blackbox_module
(
nn
.
AdaptiveLogSoftmaxWithLoss
)
TransformerEncoder
=
blackbox_module
(
nn
.
TransformerEncoder
)
TransformerDecoder
=
blackbox_module
(
nn
.
TransformerDecoder
)
TransformerEncoderLayer
=
blackbox_module
(
nn
.
TransformerEncoderLayer
)
TransformerDecoderLayer
=
blackbox_module
(
nn
.
TransformerDecoderLayer
)
Transformer
=
blackbox_module
(
nn
.
Transformer
)
Flatten
=
blackbox_module
(
nn
.
Flatten
)
Hardsigmoid
=
blackbox_module
(
nn
.
Hardsigmoid
)
if
version_larger_equal
(
torch
.
__version__
,
'1.6.0'
):
if
version_larger_equal
(
torch
.
__version__
,
'1.6.0'
):
Hardswish
=
b
lackbox_module
(
nn
.
Hardswish
)
Hardswish
=
b
asic_unit
(
nn
.
Hardswish
)
if
version_larger_equal
(
torch
.
__version__
,
'1.7.0'
):
if
version_larger_equal
(
torch
.
__version__
,
'1.7.0'
):
SiLU
=
b
lackbox_module
(
nn
.
SiLU
)
SiLU
=
b
asic_unit
(
nn
.
SiLU
)
Unflatten
=
b
lackbox_module
(
nn
.
Unflatten
)
Unflatten
=
b
asic_unit
(
nn
.
Unflatten
)
TripletMarginWithDistanceLoss
=
b
lackbox_module
(
nn
.
TripletMarginWithDistanceLoss
)
TripletMarginWithDistanceLoss
=
b
asic_unit
(
nn
.
TripletMarginWithDistanceLoss
)
nni/retiarii/
trainer
/__init__.py
→
nni/retiarii/
oneshot
/__init__.py
View file @
cae4308f
from
.functional
import
FunctionalTrainer
from
.interface
import
BaseOneShotTrainer
from
.interface
import
BaseOneShotTrainer
nni/retiarii/
trainer
/interface.py
→
nni/retiarii/
oneshot
/interface.py
View file @
cae4308f
...
@@ -2,11 +2,6 @@ import abc
...
@@ -2,11 +2,6 @@ import abc
from
typing
import
Any
from
typing
import
Any
class
BaseTrainer
(
abc
.
ABC
):
# Deprecated class
pass
class
BaseOneShotTrainer
(
abc
.
ABC
):
class
BaseOneShotTrainer
(
abc
.
ABC
):
"""
"""
Build many (possibly all) architectures into a full graph, search (with train) and export the best.
Build many (possibly all) architectures into a full graph, search (with train) and export the best.
...
...
nni/retiarii/
trainer
/pytorch/__init__.py
→
nni/retiarii/
oneshot
/pytorch/__init__.py
View file @
cae4308f
from
.base
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
.darts
import
DartsTrainer
from
.darts
import
DartsTrainer
from
.enas
import
EnasTrainer
from
.enas
import
EnasTrainer
from
.proxyless
import
ProxylessTrainer
from
.proxyless
import
ProxylessTrainer
from
.random
import
RandomTrainer
,
SinglePathTrainer
from
.random
import
SinglePathTrainer
,
RandomTrainer
from
.utils
import
replace_input_choice
,
replace_layer_choice
nni/retiarii/
trainer
/pytorch/darts.py
→
nni/retiarii/
oneshot
/pytorch/darts.py
View file @
cae4308f
File moved
nni/retiarii/
trainer
/pytorch/enas.py
→
nni/retiarii/
oneshot
/pytorch/enas.py
View file @
cae4308f
File moved
nni/retiarii/
trainer
/pytorch/proxyless.py
→
nni/retiarii/
oneshot
/pytorch/proxyless.py
View file @
cae4308f
File moved
nni/retiarii/
trainer
/pytorch/random.py
→
nni/retiarii/
oneshot
/pytorch/random.py
View file @
cae4308f
File moved
nni/retiarii/
trainer
/pytorch/utils.py
→
nni/retiarii/
oneshot
/pytorch/utils.py
View file @
cae4308f
File moved
nni/retiarii/serializer.py
0 → 100644
View file @
cae4308f
import
abc
import
functools
import
inspect
from
typing
import
Any
import
json_tricks
from
.utils
import
get_full_class_name
,
get_module_name
,
import_
def
get_init_parameters_or_fail
(
obj
,
silently
=
False
):
if
hasattr
(
obj
,
'_init_parameters'
):
return
obj
.
_init_parameters
elif
silently
:
return
None
else
:
raise
ValueError
(
f
'Object
{
obj
}
needs to be serializable but `_init_parameters` is not available. '
'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
'If it is a customized module, please to decorate it with @basic_unit. '
'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
'try to use serialize or @serialize_cls.'
)
### This is a patch of json-tricks to make it more useful to us ###
def
_serialize_class_instance_encode
(
obj
,
primitives
=
False
):
assert
not
primitives
,
'Encoding with primitives is not supported.'
try
:
# FIXME: raise error
if
hasattr
(
obj
,
'__class__'
):
return
{
'__type__'
:
get_full_class_name
(
obj
.
__class__
),
'arguments'
:
get_init_parameters_or_fail
(
obj
)
}
except
ValueError
:
pass
return
obj
def
_serialize_class_instance_decode
(
obj
):
if
isinstance
(
obj
,
dict
)
and
'__type__'
in
obj
and
'arguments'
in
obj
:
return
import_
(
obj
[
'__type__'
])(
**
obj
[
'arguments'
])
return
obj
def
_type_encode
(
obj
,
primitives
=
False
):
assert
not
primitives
,
'Encoding with primitives is not supported.'
if
isinstance
(
obj
,
type
):
return
{
'__typename__'
:
get_full_class_name
(
obj
,
relocate_module
=
True
)}
return
obj
def
_type_decode
(
obj
):
if
isinstance
(
obj
,
dict
)
and
'__typename__'
in
obj
:
return
import_
(
obj
[
'__typename__'
])
return
obj
json_loads
=
functools
.
partial
(
json_tricks
.
loads
,
extra_obj_pairs_hooks
=
[
_serialize_class_instance_decode
,
_type_decode
])
json_dumps
=
functools
.
partial
(
json_tricks
.
dumps
,
extra_obj_encoders
=
[
_serialize_class_instance_encode
,
_type_encode
])
json_load
=
functools
.
partial
(
json_tricks
.
load
,
extra_obj_pairs_hooks
=
[
_serialize_class_instance_decode
,
_type_decode
])
json_dump
=
functools
.
partial
(
json_tricks
.
dump
,
extra_obj_encoders
=
[
_serialize_class_instance_encode
,
_type_encode
])
### End of json-tricks patch ###
class
Translatable
(
abc
.
ABC
):
"""
Inherit this class and implement ``translate`` when the inner class needs a different
parameter from the wrapper class in its init function.
"""
@
abc
.
abstractmethod
def
_translate
(
self
)
->
Any
:
pass
def
_create_wrapper_cls
(
cls
,
store_init_parameters
=
True
):
class
wrapper
(
cls
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
if
store_init_parameters
:
argname_list
=
list
(
inspect
.
signature
(
cls
.
__init__
).
parameters
.
keys
())[
1
:]
full_args
=
{}
full_args
.
update
(
kwargs
)
assert
len
(
args
)
<=
len
(
argname_list
),
f
'Length of
{
args
}
is greater than length of
{
argname_list
}
.'
for
argname
,
value
in
zip
(
argname_list
,
args
):
full_args
[
argname
]
=
value
# translate parameters
args
=
list
(
args
)
for
i
,
value
in
enumerate
(
args
):
if
isinstance
(
value
,
Translatable
):
args
[
i
]
=
value
.
_translate
()
for
i
,
value
in
kwargs
.
items
():
if
isinstance
(
value
,
Translatable
):
kwargs
[
i
]
=
value
.
_translate
()
self
.
_init_parameters
=
full_args
else
:
self
.
_init_parameters
=
{}
super
().
__init__
(
*
args
,
**
kwargs
)
wrapper
.
__module__
=
get_module_name
(
cls
)
wrapper
.
__name__
=
cls
.
__name__
wrapper
.
__qualname__
=
cls
.
__qualname__
wrapper
.
__init__
.
__doc__
=
cls
.
__init__
.
__doc__
return
wrapper
def
serialize_cls
(
cls
):
"""
To create an serializable class.
"""
return
_create_wrapper_cls
(
cls
)
def
transparent_serialize
(
cls
):
"""
Wrap a module but does not record parameters. For internal use only.
"""
return
_create_wrapper_cls
(
cls
,
store_init_parameters
=
False
)
def
serialize
(
cls
,
*
args
,
**
kwargs
):
"""
To create an serializable instance inline without decorator. For example,
.. code-block:: python
self.op = serialize(MyCustomOp, hidden_units=128)
"""
return
serialize_cls
(
cls
)(
*
args
,
**
kwargs
)
def
basic_unit
(
cls
):
"""
To wrap a module as a basic unit, to stop it from parsing and make it mutate-able.
"""
import
torch.nn
as
nn
assert
issubclass
(
cls
,
nn
.
Module
),
'When using @basic_unit, the class must be a subclass of nn.Module.'
return
serialize_cls
(
cls
)
nni/retiarii/utils.py
View file @
cae4308f
import
abc
import
functools
import
inspect
import
inspect
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Any
from
typing
import
Any
from
pathlib
import
Path
from
pathlib
import
Path
import
json_tricks
def
import_
(
target
:
str
,
allow_none
:
bool
=
False
)
->
Any
:
def
import_
(
target
:
str
,
allow_none
:
bool
=
False
)
->
Any
:
if
target
is
None
:
if
target
is
None
:
...
@@ -23,145 +19,6 @@ def version_larger_equal(a: str, b: str) -> bool:
...
@@ -23,145 +19,6 @@ def version_larger_equal(a: str, b: str) -> bool:
return
tuple
(
map
(
int
,
a
.
split
(
'.'
)))
>=
tuple
(
map
(
int
,
b
.
split
(
'.'
)))
return
tuple
(
map
(
int
,
a
.
split
(
'.'
)))
>=
tuple
(
map
(
int
,
b
.
split
(
'.'
)))
### This is a patch of json-tricks to make it more useful to us ###
def
_blackbox_class_instance_encode
(
obj
,
primitives
=
False
):
assert
not
primitives
,
'Encoding with primitives is not supported.'
if
hasattr
(
obj
,
'__class__'
)
and
hasattr
(
obj
,
'__init_parameters__'
):
return
{
'__type__'
:
get_full_class_name
(
obj
.
__class__
),
'arguments'
:
obj
.
__init_parameters__
}
return
obj
def
_blackbox_class_instance_decode
(
obj
):
if
isinstance
(
obj
,
dict
)
and
'__type__'
in
obj
and
'arguments'
in
obj
:
return
import_
(
obj
[
'__type__'
])(
**
obj
[
'arguments'
])
return
obj
def
_type_encode
(
obj
,
primitives
=
False
):
assert
not
primitives
,
'Encoding with primitives is not supported.'
if
isinstance
(
obj
,
type
):
return
{
'__typename__'
:
get_full_class_name
(
obj
,
relocate_module
=
True
)}
return
obj
def
_type_decode
(
obj
):
if
isinstance
(
obj
,
dict
)
and
'__typename__'
in
obj
:
return
import_
(
obj
[
'__typename__'
])
return
obj
json_loads
=
functools
.
partial
(
json_tricks
.
loads
,
extra_obj_pairs_hooks
=
[
_blackbox_class_instance_decode
,
_type_decode
])
json_dumps
=
functools
.
partial
(
json_tricks
.
dumps
,
extra_obj_encoders
=
[
_blackbox_class_instance_encode
,
_type_encode
])
json_load
=
functools
.
partial
(
json_tricks
.
load
,
extra_obj_pairs_hooks
=
[
_blackbox_class_instance_decode
,
_type_decode
])
json_dump
=
functools
.
partial
(
json_tricks
.
dump
,
extra_obj_encoders
=
[
_blackbox_class_instance_encode
,
_type_encode
])
### End of json-tricks patch ###
_records
=
{}
def
get_records
():
global
_records
return
_records
def
clear_records
():
global
_records
_records
=
{}
def
add_record
(
key
,
value
):
"""
"""
global
_records
if
_records
is
not
None
:
assert
key
not
in
_records
,
f
'
{
key
}
already in _records. Conflict:
{
_records
[
key
]
}
'
_records
[
key
]
=
value
def
del_record
(
key
):
global
_records
if
_records
is
not
None
:
_records
.
pop
(
key
,
None
)
class
Translatable
(
abc
.
ABC
):
"""
Inherit this class and implement ``translate`` when the inner class needs a different
parameter from the wrapper class in its init function.
"""
@
abc
.
abstractmethod
def
_translate
(
self
)
->
Any
:
pass
def
_blackbox_cls
(
cls
):
class
wrapper
(
cls
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
argname_list
=
list
(
inspect
.
signature
(
cls
.
__init__
).
parameters
.
keys
())[
1
:]
full_args
=
{}
full_args
.
update
(
kwargs
)
assert
len
(
args
)
<=
len
(
argname_list
),
f
'Length of
{
args
}
is greater than length of
{
argname_list
}
.'
for
argname
,
value
in
zip
(
argname_list
,
args
):
full_args
[
argname
]
=
value
# translate parameters
args
=
list
(
args
)
for
i
,
value
in
enumerate
(
args
):
if
isinstance
(
value
,
Translatable
):
args
[
i
]
=
value
.
_translate
()
for
i
,
value
in
kwargs
.
items
():
if
isinstance
(
value
,
Translatable
):
kwargs
[
i
]
=
value
.
_translate
()
add_record
(
id
(
self
),
full_args
)
# for compatibility. Will remove soon.
self
.
__init_parameters__
=
full_args
super
().
__init__
(
*
args
,
**
kwargs
)
def
__del__
(
self
):
del_record
(
id
(
self
))
wrapper
.
__module__
=
_get_module_name
(
cls
)
wrapper
.
__name__
=
cls
.
__name__
wrapper
.
__qualname__
=
cls
.
__qualname__
wrapper
.
__init__
.
__doc__
=
cls
.
__init__
.
__doc__
return
wrapper
def
blackbox
(
cls
,
*
args
,
**
kwargs
):
"""
To create an blackbox instance inline without decorator. For example,
.. code-block:: python
self.op = blackbox(MyCustomOp, hidden_units=128)
"""
return
_blackbox_cls
(
cls
)(
*
args
,
**
kwargs
)
def
blackbox_module
(
cls
):
"""
Register a module. Use it as a decorator.
"""
return
_blackbox_cls
(
cls
)
def
register_trainer
(
cls
):
"""
Register a trainer. Use it as a decorator.
"""
return
_blackbox_cls
(
cls
)
_last_uid
=
defaultdict
(
int
)
_last_uid
=
defaultdict
(
int
)
...
@@ -170,7 +27,7 @@ def uid(namespace: str = 'default') -> int:
...
@@ -170,7 +27,7 @@ def uid(namespace: str = 'default') -> int:
return
_last_uid
[
namespace
]
return
_last_uid
[
namespace
]
def
_
get_module_name
(
cls
):
def
get_module_name
(
cls
):
module_name
=
cls
.
__module__
module_name
=
cls
.
__module__
if
module_name
==
'__main__'
:
if
module_name
==
'__main__'
:
# infer the module name with inspect
# infer the module name with inspect
...
@@ -180,7 +37,7 @@ def _get_module_name(cls):
...
@@ -180,7 +37,7 @@ def _get_module_name(cls):
main_file_path
=
Path
(
inspect
.
getsourcefile
(
frm
[
0
]))
main_file_path
=
Path
(
inspect
.
getsourcefile
(
frm
[
0
]))
if
main_file_path
.
parents
[
0
]
!=
Path
(
'.'
):
if
main_file_path
.
parents
[
0
]
!=
Path
(
'.'
):
raise
RuntimeError
(
f
'You are using "
{
main_file_path
}
" to launch your experiment, '
raise
RuntimeError
(
f
'You are using "
{
main_file_path
}
" to launch your experiment, '
f
'please launch the experiment under the directory where "
{
main_file_path
.
name
}
" is located.'
)
f
'please launch the experiment under the directory where "
{
main_file_path
.
name
}
" is located.'
)
module_name
=
main_file_path
.
stem
module_name
=
main_file_path
.
stem
break
break
...
@@ -195,5 +52,5 @@ def _get_module_name(cls):
...
@@ -195,5 +52,5 @@ def _get_module_name(cls):
def
get_full_class_name
(
cls
,
relocate_module
=
False
):
def
get_full_class_name
(
cls
,
relocate_module
=
False
):
module_name
=
_
get_module_name
(
cls
)
if
relocate_module
else
cls
.
__module__
module_name
=
get_module_name
(
cls
)
if
relocate_module
else
cls
.
__module__
return
module_name
+
'.'
+
cls
.
__name__
return
module_name
+
'.'
+
cls
.
__name__
test/retiarii_test/darts/darts_model.py
View file @
cae4308f
...
@@ -7,9 +7,9 @@ import torch.nn as torch_nn
...
@@ -7,9 +7,9 @@ import torch.nn as torch_nn
import
ops
import
ops
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
@
b
lackbox_module
@
b
asic_unit
class
AuxiliaryHead
(
nn
.
Module
):
class
AuxiliaryHead
(
nn
.
Module
):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """
""" Auxiliary head in 2/3 place of network to let the gradient flow well """
...
...
test/retiarii_test/darts/ops.py
View file @
cae4308f
import
torch
import
torch
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
@
b
lackbox_module
@
b
asic_unit
class
DropPath
(
nn
.
Module
):
class
DropPath
(
nn
.
Module
):
def
__init__
(
self
,
p
=
0.
):
def
__init__
(
self
,
p
=
0.
):
"""
"""
...
@@ -24,7 +24,7 @@ class DropPath(nn.Module):
...
@@ -24,7 +24,7 @@ class DropPath(nn.Module):
return
x
return
x
@
b
lackbox_module
@
b
asic_unit
class
PoolBN
(
nn
.
Module
):
class
PoolBN
(
nn
.
Module
):
"""
"""
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
...
@@ -45,7 +45,7 @@ class PoolBN(nn.Module):
...
@@ -45,7 +45,7 @@ class PoolBN(nn.Module):
out
=
self
.
bn
(
out
)
out
=
self
.
bn
(
out
)
return
out
return
out
@
b
lackbox_module
@
b
asic_unit
class
StdConv
(
nn
.
Module
):
class
StdConv
(
nn
.
Module
):
"""
"""
Standard conv: ReLU - Conv - BN
Standard conv: ReLU - Conv - BN
...
@@ -61,7 +61,7 @@ class StdConv(nn.Module):
...
@@ -61,7 +61,7 @@ class StdConv(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
return
self
.
net
(
x
)
@
b
lackbox_module
@
b
asic_unit
class
FacConv
(
nn
.
Module
):
class
FacConv
(
nn
.
Module
):
"""
"""
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
...
@@ -78,7 +78,7 @@ class FacConv(nn.Module):
...
@@ -78,7 +78,7 @@ class FacConv(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
return
self
.
net
(
x
)
@
b
lackbox_module
@
b
asic_unit
class
DilConv
(
nn
.
Module
):
class
DilConv
(
nn
.
Module
):
"""
"""
(Dilated) depthwise separable conv.
(Dilated) depthwise separable conv.
...
@@ -98,7 +98,7 @@ class DilConv(nn.Module):
...
@@ -98,7 +98,7 @@ class DilConv(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
return
self
.
net
(
x
)
@
b
lackbox_module
@
b
asic_unit
class
SepConv
(
nn
.
Module
):
class
SepConv
(
nn
.
Module
):
"""
"""
Depthwise separable conv.
Depthwise separable conv.
...
@@ -114,7 +114,7 @@ class SepConv(nn.Module):
...
@@ -114,7 +114,7 @@ class SepConv(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
return
self
.
net
(
x
)
@
b
lackbox_module
@
b
asic_unit
class
FactorizedReduce
(
nn
.
Module
):
class
FactorizedReduce
(
nn
.
Module
):
"""
"""
Reduce feature map size by factorized pointwise (stride=2).
Reduce feature map size by factorized pointwise (stride=2).
...
...
test/retiarii_test/darts/test.py
View file @
cae4308f
...
@@ -4,9 +4,9 @@ import sys
...
@@ -4,9 +4,9 @@ import sys
import
torch
import
torch
from
pathlib
import
Path
from
pathlib
import
Path
import
nni.retiarii.
traine
r.pytorch.lightning
as
pl
import
nni.retiarii.
evaluato
r.pytorch.lightning
as
pl
import
nni.retiarii.strategy
as
strategy
import
nni.retiarii.strategy
as
strategy
from
nni.retiarii
import
blackbox_module
as
bm
from
nni.retiarii
import
serialize
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
torchvision
import
transforms
from
torchvision
import
transforms
from
torchvision.datasets
import
CIFAR10
from
torchvision.datasets
import
CIFAR10
...
@@ -27,8 +27,8 @@ if __name__ == '__main__':
...
@@ -27,8 +27,8 @@ if __name__ == '__main__':
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
)),
])
])
train_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
True
,
download
=
True
,
transform
=
train_transform
)
train_dataset
=
serialize
(
CIFAR10
,
root
=
'data/cifar10'
,
train
=
True
,
download
=
True
,
transform
=
train_transform
)
test_dataset
=
bm
(
CIFAR10
)(
root
=
'data/cifar10'
,
train
=
False
,
download
=
True
,
transform
=
valid_transform
)
test_dataset
=
serialize
(
CIFAR10
,
root
=
'data/cifar10'
,
train
=
False
,
download
=
True
,
transform
=
valid_transform
)
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
trainer
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
1
,
limit_train_batches
=
0.2
)
max_epochs
=
1
,
limit_train_batches
=
0.2
)
...
...
test/retiarii_test/darts/test_oneshot.py
View file @
cae4308f
...
@@ -9,7 +9,7 @@ from torchvision import transforms
...
@@ -9,7 +9,7 @@ from torchvision import transforms
from
torchvision.datasets
import
CIFAR10
from
torchvision.datasets
import
CIFAR10
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
from
nni.retiarii.
trainer
.pytorch
import
DartsTrainer
from
nni.retiarii.
oneshot
.pytorch
import
DartsTrainer
from
darts_model
import
CNN
from
darts_model
import
CNN
...
...
test/retiarii_test/mnasnet/base_mnasnet.py
View file @
cae4308f
from
nni.retiarii
import
b
lackbox_module
from
nni.retiarii
import
b
asic_unit
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
import
warnings
import
warnings
...
@@ -148,7 +148,7 @@ class MNASNet(nn.Module):
...
@@ -148,7 +148,7 @@ class MNASNet(nn.Module):
# zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios):
# zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios):
for
filter_size
,
exp_ratio
,
stride
in
zip
(
base_filter_sizes
,
exp_ratios
,
strides
):
for
filter_size
,
exp_ratio
,
stride
in
zip
(
base_filter_sizes
,
exp_ratios
,
strides
):
# TODO: restrict that "choose" can only be used within mutator
# TODO: restrict that "choose" can only be used within mutator
ph
=
nn
.
Placeholder
(
label
=
f
'mutable_
{
count
}
'
,
related_info
=
{
ph
=
nn
.
Placeholder
(
label
=
f
'mutable_
{
count
}
'
,
**
{
'kernel_size_options'
:
[
1
,
3
,
5
],
'kernel_size_options'
:
[
1
,
3
,
5
],
'n_layer_options'
:
[
1
,
2
,
3
,
4
],
'n_layer_options'
:
[
1
,
2
,
3
,
4
],
'op_type_options'
:
[
'__mutated__.base_mnasnet.RegularConv'
,
'op_type_options'
:
[
'__mutated__.base_mnasnet.RegularConv'
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment