Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
18962129
Unverified
Commit
18962129
authored
Apr 25, 2022
by
Yuge Zhang
Committed by
GitHub
Apr 25, 2022
Browse files
Add license header and typehints for NAS (#4774)
parent
8c2f717d
Changes
96
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
75 additions
and
29 deletions
+75
-29
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+8
-8
nni/retiarii/strategy/_rl_impl.py
nni/retiarii/strategy/_rl_impl.py
+19
-7
nni/retiarii/strategy/base.py
nni/retiarii/strategy/base.py
+1
-1
nni/retiarii/strategy/bruteforce.py
nni/retiarii/strategy/bruteforce.py
+3
-1
nni/retiarii/strategy/rl.py
nni/retiarii/strategy/rl.py
+3
-0
nni/retiarii/strategy/tpe_strategy.py
nni/retiarii/strategy/tpe_strategy.py
+4
-2
nni/retiarii/strategy/utils.py
nni/retiarii/strategy/utils.py
+0
-1
nni/retiarii/trial_entry.py
nni/retiarii/trial_entry.py
+2
-0
nni/retiarii/utils.py
nni/retiarii/utils.py
+6
-4
nni/typehint.py
nni/typehint.py
+1
-1
pipelines/fast-test.yml
pipelines/fast-test.yml
+3
-0
pyrightconfig.json
pyrightconfig.json
+8
-4
test/pytest.ini
test/pytest.ini
+1
-0
test/ut/retiarii/mnist_pytorch.json
test/ut/retiarii/mnist_pytorch.json
+4
-0
test/ut/retiarii/test_engine.py
test/ut/retiarii/test_engine.py
+2
-0
test/vso_tools/trigger_import.py
test/vso_tools/trigger_import.py
+10
-0
No files found.
nni/retiarii/serializer.py
View file @
18962129
...
...
@@ -4,9 +4,9 @@
import
inspect
import
os
import
warnings
from
typing
import
Any
,
TypeVar
,
Union
from
typing
import
Any
,
TypeVar
,
Type
from
nni.common.serializer
import
Traceable
,
is_traceable
,
is_wrapped_with_trace
,
trace
,
_copy_class_wrapper_attributes
from
nni.common.serializer
import
is_traceable
,
is_wrapped_with_trace
,
trace
,
_copy_class_wrapper_attributes
from
.utils
import
ModelNamespace
__all__
=
[
'get_init_parameters_or_fail'
,
'serialize'
,
'serialize_cls'
,
'basic_unit'
,
'model_wrapper'
,
...
...
@@ -48,7 +48,7 @@ def serialize_cls(cls):
return
trace
(
cls
)
def
basic_unit
(
cls
:
T
,
basic_unit_tag
:
bool
=
True
)
->
Union
[
T
,
Traceable
]
:
def
basic_unit
(
cls
:
T
,
basic_unit_tag
:
bool
=
True
)
->
T
:
"""
To wrap a module as a basic unit, is to make it a primitive and stop the engine from digging deeper into it.
...
...
@@ -75,17 +75,17 @@ def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
return
cls
import
torch.nn
as
nn
assert
issubclass
(
cls
,
nn
.
Module
),
'When using @basic_unit, the class must be a subclass of nn.Module.'
assert
issubclass
(
cls
,
nn
.
Module
),
'When using @basic_unit, the class must be a subclass of nn.Module.'
# type: ignore
cls
=
trace
(
cls
)
cls
.
_nni_basic_unit
=
basic_unit_tag
cls
.
_nni_basic_unit
=
basic_unit_tag
# type: ignore
_torchscript_patch
(
cls
)
return
cls
def
model_wrapper
(
cls
:
T
)
->
Union
[
T
,
Traceable
]
:
def
model_wrapper
(
cls
:
T
)
->
T
:
"""
Wrap the base model (search space). For example,
...
...
@@ -113,7 +113,7 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
return
cls
import
torch.nn
as
nn
assert
issubclass
(
cls
,
nn
.
Module
)
assert
issubclass
(
cls
,
nn
.
Module
)
# type: ignore
# subclass can still use trace info
wrapper
=
trace
(
cls
,
inheritable
=
True
)
...
...
@@ -146,7 +146,7 @@ def is_model_wrapped(cls_or_instance) -> bool:
return
getattr
(
cls_or_instance
,
'_nni_model_wrapper'
,
False
)
def
_check_wrapped
(
cls
:
T
,
rewrap
:
str
)
->
bool
:
def
_check_wrapped
(
cls
:
T
ype
,
rewrap
:
str
)
->
bool
:
wrapped
=
None
if
is_model_wrapped
(
cls
):
wrapped
=
'model_wrapper'
...
...
nni/retiarii/strategy/_rl_impl.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# This file might cause import error for those who didn't install RL-related dependencies
import
logging
import
threading
from
multiprocessing.pool
import
ThreadPool
from
typing
import
Tuple
import
gym
import
numpy
as
np
import
tianshou
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
gym
import
spaces
from
tianshou.data
import
to_torch
from
tianshou.env.worker
import
EnvWorker
from
nni.typehint
import
TypedDict
from
.utils
import
get_targeted_model
from
..graph
import
ModelStatus
from
..execution
import
submit_models
,
wait_models
...
...
@@ -76,8 +83,13 @@ class MultiThreadEnvWorker(EnvWorker):
self
.
pool
.
terminate
()
return
self
.
env
.
close
()
class
ObservationType
(
TypedDict
):
action_history
:
np
.
ndarray
cur_step
:
int
action_dim
:
int
class
ModelEvaluationEnv
(
gym
.
Env
):
class
ModelEvaluationEnv
(
gym
.
Env
[
ObservationType
,
int
]):
def
__init__
(
self
,
base_model
,
mutators
,
search_space
):
self
.
base_model
=
base_model
self
.
mutators
=
mutators
...
...
@@ -98,7 +110,7 @@ class ModelEvaluationEnv(gym.Env):
def
action_space
(
self
):
return
spaces
.
Discrete
(
self
.
action_dim
)
def
reset
(
self
):
def
reset
(
self
)
->
ObservationType
:
self
.
action_history
=
np
.
zeros
(
self
.
num_steps
,
dtype
=
np
.
int32
)
self
.
cur_step
=
0
self
.
sample
=
{}
...
...
@@ -108,14 +120,14 @@ class ModelEvaluationEnv(gym.Env):
'action_dim'
:
len
(
self
.
search_space
[
self
.
ss_keys
[
self
.
cur_step
]])
}
def
step
(
self
,
action
)
:
def
step
(
self
,
action
:
int
)
->
Tuple
[
ObservationType
,
float
,
bool
,
dict
]
:
cur_key
=
self
.
ss_keys
[
self
.
cur_step
]
assert
action
<
len
(
self
.
search_space
[
cur_key
]),
\
f
'Current action
{
action
}
out of range
{
self
.
search_space
[
cur_key
]
}
.'
self
.
action_history
[
self
.
cur_step
]
=
action
self
.
sample
[
cur_key
]
=
self
.
search_space
[
cur_key
][
action
]
self
.
cur_step
+=
1
obs
=
{
obs
:
ObservationType
=
{
'action_history'
:
self
.
action_history
,
'cur_step'
:
self
.
cur_step
,
'action_dim'
:
len
(
self
.
search_space
[
self
.
ss_keys
[
self
.
cur_step
]])
\
...
...
@@ -129,7 +141,7 @@ class ModelEvaluationEnv(gym.Env):
wait_models
(
model
)
if
model
.
status
==
ModelStatus
.
Failed
:
return
self
.
reset
(),
0.
,
False
,
{}
rew
=
model
.
metric
rew
=
float
(
model
.
metric
)
_logger
.
info
(
f
'Model metric received as reward:
{
rew
}
'
)
return
obs
,
rew
,
True
,
{}
else
:
...
...
@@ -147,7 +159,7 @@ class Preprocessor(nn.Module):
self
.
rnn
=
nn
.
LSTM
(
hidden_dim
,
hidden_dim
,
num_layers
,
batch_first
=
True
)
def
forward
(
self
,
obs
):
seq
=
nn
.
functional
.
pad
(
obs
[
'action_history'
]
+
1
,
(
1
,
1
))
# pad the start token and end token
seq
=
F
.
pad
(
obs
[
'action_history'
]
+
1
,
(
1
,
1
))
# pad the start token and end token
# end token is used to avoid out-of-range of v_s_. Will not actually affect BP.
seq
=
self
.
embedding
(
seq
.
long
())
feature
,
_
=
self
.
rnn
(
seq
)
...
...
@@ -167,7 +179,7 @@ class Actor(nn.Module):
# to take care of choices with different number of options
mask
=
torch
.
arange
(
self
.
action_dim
).
expand
(
len
(
out
),
self
.
action_dim
)
>=
obs
[
'action_dim'
].
unsqueeze
(
1
)
out
[
mask
.
to
(
out
.
device
)]
=
float
(
'-inf'
)
return
nn
.
functional
.
softmax
(
out
,
dim
=-
1
),
kwargs
.
get
(
'state'
,
None
)
return
F
.
softmax
(
out
,
dim
=-
1
),
kwargs
.
get
(
'state'
,
None
)
class
Critic
(
nn
.
Module
):
...
...
nni/retiarii/strategy/base.py
View file @
18962129
...
...
@@ -14,5 +14,5 @@ class BaseStrategy(abc.ABC):
def
run
(
self
,
base_model
:
Model
,
applied_mutators
:
List
[
Mutator
])
->
None
:
pass
def
export_top_models
(
self
)
->
List
[
Any
]:
def
export_top_models
(
self
,
top_k
:
int
)
->
List
[
Any
]:
raise
NotImplementedError
(
'"export_top_models" is not implemented.'
)
nni/retiarii/strategy/bruteforce.py
View file @
18962129
...
...
@@ -6,7 +6,7 @@ import itertools
import
logging
import
random
import
time
from
typing
import
Any
,
Dict
,
List
from
typing
import
Any
,
Dict
,
List
,
Sequence
,
Optional
from
..
import
InvalidMutation
,
Sampler
,
submit_models
,
query_available_resources
,
budget_exhausted
from
.base
import
BaseStrategy
...
...
@@ -30,6 +30,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
history
=
set
()
search_space_values
=
copy
.
deepcopy
(
list
(
search_space
.
values
()))
while
True
:
selected
:
Optional
[
Sequence
[
int
]]
=
None
for
retry_count
in
range
(
retries
):
selected
=
[
random
.
choice
(
v
)
for
v
in
search_space_values
]
if
not
dedup
:
...
...
@@ -41,6 +42,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
if
retry_count
+
1
==
retries
:
_logger
.
debug
(
'Random generation has run out of patience. There is nothing to search. Exiting.'
)
return
assert
selected
is
not
None
,
'Retry attempts exhausted.'
yield
{
key
:
value
for
key
,
value
in
zip
(
keys
,
selected
)}
...
...
nni/retiarii/strategy/rl.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
from
typing
import
Optional
,
Callable
...
...
nni/retiarii/strategy/tpe_strategy.py
View file @
18962129
...
...
@@ -3,6 +3,7 @@
import
logging
import
time
from
typing
import
Optional
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
...
...
@@ -15,8 +16,8 @@ _logger = logging.getLogger(__name__)
class
TPESampler
(
Sampler
):
def
__init__
(
self
,
optimize_mode
=
'minimize'
):
self
.
tpe_tuner
=
HyperoptTuner
(
'tpe'
,
optimize_mode
)
self
.
cur_sample
=
None
self
.
index
=
None
self
.
cur_sample
:
Optional
[
dict
]
=
None
self
.
index
:
Optional
[
int
]
=
None
self
.
total_parameters
=
{}
def
update_sample_space
(
self
,
sample_space
):
...
...
@@ -34,6 +35,7 @@ class TPESampler(Sampler):
self
.
tpe_tuner
.
receive_trial_result
(
model_id
,
self
.
total_parameters
[
model_id
],
result
)
def
choice
(
self
,
candidates
,
mutator
,
model
,
index
):
assert
isinstance
(
self
.
index
,
int
)
and
isinstance
(
self
.
cur_sample
,
dict
)
chosen
=
self
.
cur_sample
[
str
(
self
.
index
)]
self
.
index
+=
1
return
chosen
...
...
nni/retiarii/strategy/utils.py
View file @
18962129
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
collections
import
logging
from
typing
import
Dict
,
Any
,
List
...
...
nni/retiarii/trial_entry.py
View file @
18962129
...
...
@@ -25,4 +25,6 @@ if __name__ == '__main__':
elif
args
.
exec
==
'benchmark'
:
from
.execution.benchmark
import
BenchmarkExecutionEngine
engine
=
BenchmarkExecutionEngine
else
:
raise
ValueError
(
f
'Unrecognized benchmark name:
{
args
.
exec
}
'
)
engine
.
trial_execute_graph
()
nni/retiarii/utils.py
View file @
18962129
...
...
@@ -6,7 +6,7 @@ import itertools
import
warnings
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
typing
import
Any
,
List
,
Dict
from
typing
import
Any
,
List
,
Dict
,
cast
from
pathlib
import
Path
from
nni.common.hpo_utils
import
ParameterSpec
...
...
@@ -41,9 +41,10 @@ def get_module_name(cls_or_func):
if
module_name
==
'__main__'
:
# infer the module name with inspect
for
frm
in
inspect
.
stack
():
if
inspect
.
getmodule
(
frm
[
0
]).
__name__
==
'__main__'
:
module
=
inspect
.
getmodule
(
frm
[
0
])
if
module
is
not
None
and
module
.
__name__
==
'__main__'
:
# main module found
main_file_path
=
Path
(
inspect
.
getsourcefile
(
frm
[
0
]))
main_file_path
=
Path
(
cast
(
str
,
inspect
.
getsourcefile
(
frm
[
0
]))
)
if
not
Path
().
samefile
(
main_file_path
.
parent
):
raise
RuntimeError
(
f
'You are using "
{
main_file_path
}
" to launch your experiment, '
f
'please launch the experiment under the directory where "
{
main_file_path
.
name
}
" is located.'
)
...
...
@@ -227,6 +228,7 @@ def original_state_dict_hooks(model: Any):
supernet_style_state_dict = model.state_dict()
"""
import
torch.utils.hooks
import
torch.nn
as
nn
assert
isinstance
(
model
,
nn
.
Module
),
'PyTorch is the only supported framework for now.'
...
...
@@ -297,8 +299,8 @@ def original_state_dict_hooks(model: Any):
raise
KeyError
(
f
'"
{
src
}
" not in state dict, but found in mapping.'
)
destination
.
update
(
result
)
hooks
:
List
[
torch
.
utils
.
hooks
.
RemovableHandle
]
=
[]
try
:
hooks
=
[]
hooks
.
append
(
model
.
_register_load_state_dict_pre_hook
(
load_state_dict_hook
))
hooks
.
append
(
model
.
_register_state_dict_hook
(
state_dict_hook
))
yield
...
...
nni/typehint.py
View file @
18962129
...
...
@@ -6,7 +6,7 @@ Types for static checking.
"""
__all__
=
[
'Literal'
,
'Literal'
,
'TypedDict'
,
'Parameters'
,
'SearchSpace'
,
'TrialMetric'
,
'TrialRecord'
,
]
...
...
pipelines/fast-test.yml
View file @
18962129
...
...
@@ -64,6 +64,9 @@ stages:
python -m pip install "typing-extensions>=3.10"
displayName
:
Resolve dependency version
-
script
:
python test/vso_tools/trigger_import.py
displayName
:
Trigger import
-
script
:
|
python -m pylint --rcfile pylintrc nni
displayName
:
pylint
...
...
pyrightconfig.json
View file @
18962129
...
...
@@ -3,10 +3,13 @@
"nni/algorithms"
,
"nni/common/device.py"
,
"nni/common/graph_utils.py"
,
"nni/common/serializer.py"
,
"nni/compression"
,
"nni/nas"
,
"nni/retiarii"
,
"nni/nas/tensorflow"
,
"nni/nas/pytorch"
,
"nni/retiarii/execution/cgo_engine.py"
,
"nni/retiarii/execution/logical_optimizer"
,
"nni/retiarii/evaluator/pytorch/cgo"
,
"nni/retiarii/oneshot"
,
"nni/smartparam.py"
,
"nni/tools/annotation"
,
"nni/tools/gpu_tool"
,
...
...
@@ -14,5 +17,6 @@
"nni/tools/nnictl"
,
"nni/tools/trial_tool"
],
"reportMissingImports"
:
false
"reportMissingImports"
:
false
,
"reportPrivateImportUsage"
:
false
}
test/pytest.ini
View file @
18962129
...
...
@@ -4,4 +4,5 @@ filterwarnings =
ignore:Using
key
to
access
the
identifier
of:DeprecationWarning
ignore:layer_choice.choices
is
deprecated.:DeprecationWarning
ignore:The
truth
value
of
an
empty
array
is
ambiguous.:DeprecationWarning
ignore:`np.bool`
is
a
deprecated
alias
for
the
builtin
`bool`:DeprecationWarning
ignore:nni.retiarii.serialize
is
deprecated
and
will
be
removed
in
future
release.:DeprecationWarning
test/ut/retiarii/mnist_pytorch.json
View file @
18962129
...
...
@@ -36,5 +36,9 @@
{
"head"
:
[
"conv2"
,
null
],
"tail"
:
[
"pool2"
,
null
]},
{
"head"
:
[
"pool2"
,
null
],
"tail"
:
[
"_outputs"
,
0
]}
]
},
"_evaluator"
:
{
"type"
:
"DebugEvaluator"
}
}
test/ut/retiarii/test_engine.py
View file @
18962129
...
...
@@ -9,6 +9,7 @@ from nni.retiarii.codegen import model_to_pytorch_script
from
nni.retiarii.execution
import
set_execution_engine
from
nni.retiarii.execution.base
import
BaseExecutionEngine
from
nni.retiarii.execution.python
import
PurePythonExecutionEngine
from
nni.retiarii.graph
import
DebugEvaluator
from
nni.retiarii.integration
import
RetiariiAdvisor
...
...
@@ -51,6 +52,7 @@ class EngineTest(unittest.TestCase):
'edges'
:
[]
}
})
model
.
evaluator
=
DebugEvaluator
()
model
.
python_class
=
object
submit_models
(
model
,
model
)
...
...
test/vso_tools/trigger_import.py
0 → 100644
View file @
18962129
"""Trigger import of some modules to write some caches,
so that static analysis (e.g., pyright) can know the type."""
import
os
import
sys
sys
.
path
.
insert
(
0
,
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'../../'
))
import
nni
import
nni.retiarii.nn.pytorch
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment