Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
8d5f643c
Unverified
Commit
8d5f643c
authored
Mar 23, 2022
by
Yuge Zhang
Committed by
GitHub
Mar 23, 2022
Browse files
Hyper-parameter Choice in Retiarii (#4609)
parent
ba771871
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
268 additions
and
21 deletions
+268
-21
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+147
-4
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+8
-1
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+2
-1
nni/retiarii/utils.py
nni/retiarii/utils.py
+51
-14
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+60
-1
No files found.
nni/retiarii/nn/pytorch/api.py
View file @
8d5f643c
...
@@ -5,18 +5,19 @@ import math
...
@@ -5,18 +5,19 @@ import math
import
itertools
import
itertools
import
operator
import
operator
import
warnings
import
warnings
from
typing
import
Any
,
List
,
Union
,
Dict
,
Optional
,
Callable
,
Iterable
,
NoReturn
,
TypeVar
from
typing
import
Any
,
List
,
Union
,
Dict
,
Optional
,
Callable
,
Iterable
,
NoReturn
,
TypeVar
,
Sequence
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
nni.common.hpo_utils
import
ParameterSpec
from
nni.common.serializer
import
Translatable
from
nni.common.serializer
import
Translatable
from
nni.retiarii.serializer
import
basic_unit
from
nni.retiarii.serializer
import
basic_unit
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING_PARTIAL
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING_PARTIAL
,
ModelNamespace
,
NoContextError
from
.mutation_utils
import
Mutable
,
generate_new_label
,
get_fixed_value
from
.mutation_utils
import
Mutable
,
generate_new_label
,
get_fixed_value
__all__
=
[
'LayerChoice'
,
'InputChoice'
,
'ValueChoice'
,
'Placeholder'
,
'ChosenInputs'
]
__all__
=
[
'LayerChoice'
,
'InputChoice'
,
'ValueChoice'
,
'ModelParameterChoice'
,
'Placeholder'
,
'ChosenInputs'
]
class
LayerChoice
(
Mutable
):
class
LayerChoice
(
Mutable
):
...
@@ -870,7 +871,6 @@ class ValueChoice(ValueChoiceX, Mutable):
...
@@ -870,7 +871,6 @@ class ValueChoice(ValueChoiceX, Mutable):
self
.
prior
=
prior
or
[
1
/
len
(
candidates
)
for
_
in
range
(
len
(
candidates
))]
self
.
prior
=
prior
or
[
1
/
len
(
candidates
)
for
_
in
range
(
len
(
candidates
))]
assert
abs
(
sum
(
self
.
prior
)
-
1
)
<
1e-5
,
'Sum of prior distribution is not 1.'
assert
abs
(
sum
(
self
.
prior
)
-
1
)
<
1e-5
,
'Sum of prior distribution is not 1.'
self
.
_label
=
generate_new_label
(
label
)
self
.
_label
=
generate_new_label
(
label
)
self
.
_accessor
=
[]
@
property
@
property
def
label
(
self
):
def
label
(
self
):
...
@@ -906,6 +906,149 @@ class ValueChoice(ValueChoiceX, Mutable):
...
@@ -906,6 +906,149 @@ class ValueChoice(ValueChoiceX, Mutable):
return
f
'ValueChoice(
{
self
.
candidates
}
, label=
{
repr
(
self
.
label
)
}
)'
return
f
'ValueChoice(
{
self
.
candidates
}
, label=
{
repr
(
self
.
label
)
}
)'
ValueType
=
TypeVar
(
'ValueType'
)
class
ModelParameterChoice
:
"""ModelParameterChoice chooses one hyper-parameter from ``candidates``.
.. attention::
This API is internal, and does not guarantee forward-compatibility.
It's quite similar to :class:`ValueChoice`, but unlike :class:`ValueChoice`,
it always returns a fixed value, even at the construction of base model.
This makes it highly flexible (e.g., can be used in for-loop, if-condition, as argument of any function). For example: ::
self.has_auxiliary_head = ModelParameterChoice([False, True])
# this will raise error if you use `ValueChoice`
if self.has_auxiliary_head is True: # or self.has_auxiliary_head
self.auxiliary_head = Head()
else:
self.auxiliary_head = None
print(type(self.has_auxiliary_head)) # <class 'bool'>
The working mechanism of :class:`ModelParameterChoice` is that, it registers itself
in the ``model_wrapper``, as a hyper-parameter of the model, and then returns the value specified with ``default``.
At base model construction, the default value will be used (as a mocked hyper-parameter).
In trial, the hyper-parameter selected by strategy will be used.
Although flexible, we still recommend using :class:`ValueChoice` in favor of :class:`ModelParameterChoice`,
because information are lost when using :class:`ModelParameterChoice` in exchange of its flexibility,
making it incompatible with one-shot strategies and non-python execution engines.
.. warning::
:class:`ModelParameterChoice` can NOT be nested.
.. tip::
Although called :class:`ModelParameterChoice`, it's meant to tune hyper-parameter of architecture.
It's NOT used to tune model-training hyper-parameters like ``learning_rate``.
If you need to tune ``learning_rate``, please use :class:`ValueChoice` on arguments of :class:`nni.retiarii.Evaluator`.
Parameters
----------
candidates : list of any
List of values to choose from.
prior : list of float
Prior distribution to sample from. Currently has no effect.
default : Callable[[List[Any]], Any] or Any
Function that selects one from ``candidates``, or a candidate.
Use :meth:`ModelParameterChoice.FIRST` or :meth:`ModelParameterChoice.LAST` to take the first or last item.
Default: :meth:`ModelParameterChoice.FIRST`
label : str
Identifier of the value choice.
Warnings
--------
:class:`ModelParameterChoice` is incompatible with one-shot strategies and non-python execution engines.
Sometimes, the same search space implemented **without** :class:`ModelParameterChoice` can be simpler, and explored
with more types of search strategies. For example, the following usages are equivalent: ::
# with ModelParameterChoice
depth = nn.ModelParameterChoice(list(range(3, 10)))
blocks = []
for i in range(depth):
blocks.append(Block())
# w/o HyperParmaeterChoice
blocks = Repeat(Block(), (3, 9))
Examples
--------
Get a dynamic-shaped parameter. Because ``torch.zeros`` is not a basic unit, we can't use :class:`ValueChoice` on it.
>>> parameter_dim = nn.ModelParameterChoice([64, 128, 256])
>>> self.token = nn.Parameter(torch.zeros(1, parameter_dim, 32, 32))
"""
# FIXME: fix signature in docs
# FIXME: prior is designed but not supported yet
def
__new__
(
cls
,
candidates
:
List
[
ValueType
],
*
,
prior
:
Optional
[
List
[
float
]]
=
None
,
default
:
Union
[
Callable
[[
List
[
ValueType
]],
ValueType
],
ValueType
]
=
None
,
label
:
Optional
[
str
]
=
None
)
->
ValueType
:
# Actually, creating a `ModelParameterChoice` never creates one.
# It always return a fixed value, and register a ParameterSpec
if
default
is
None
:
default
=
cls
.
FIRST
try
:
return
cls
.
create_fixed_module
(
candidates
,
label
=
label
)
except
NoContextError
:
return
cls
.
create_default
(
candidates
,
default
,
label
)
@
staticmethod
def
create_default
(
candidates
:
List
[
ValueType
],
default
:
Union
[
Callable
[[
List
[
ValueType
]],
ValueType
],
ValueType
],
label
:
Optional
[
str
])
->
ValueType
:
if
default
not
in
candidates
:
# could be callable
try
:
default
=
default
(
candidates
)
except
TypeError
as
e
:
if
'not callable'
in
str
(
e
):
raise
TypeError
(
"`default` is not in `candidates`, and it's also not callable."
)
raise
label
=
generate_new_label
(
label
)
parameter_spec
=
ParameterSpec
(
label
,
# name
'choice'
,
# TODO: support more types
candidates
,
# value
(
label
,),
# we don't have nested now
True
,
# yes, categorical
)
# there could be duplicates. Dedup is done in mutator
ModelNamespace
.
current_context
().
parameter_specs
.
append
(
parameter_spec
)
return
default
@
classmethod
def
create_fixed_module
(
cls
,
candidates
:
List
[
ValueType
],
*
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
ValueType
:
# same as ValueChoice
value
=
get_fixed_value
(
label
)
if
value
not
in
candidates
:
raise
ValueError
(
f
'Value
{
value
}
does not belong to the candidates:
{
candidates
}
.'
)
return
value
@
staticmethod
def
FIRST
(
sequence
:
Sequence
[
ValueType
])
->
ValueType
:
"""Get the first item of sequence. Useful in ``default`` argument."""
return
sequence
[
0
]
@
staticmethod
def
LAST
(
sequence
:
Sequence
[
ValueType
])
->
ValueType
:
"""Get the last item of sequence. Useful in ``default`` argument."""
return
sequence
[
-
1
]
@
basic_unit
@
basic_unit
class
Placeholder
(
nn
.
Module
):
class
Placeholder
(
nn
.
Module
):
"""
"""
...
...
nni/retiarii/nn/pytorch/mutator.py
View file @
8d5f643c
...
@@ -11,7 +11,7 @@ from nni.common.serializer import is_traceable
...
@@ -11,7 +11,7 @@ from nni.common.serializer import is_traceable
from
nni.retiarii.graph
import
Cell
,
Graph
,
Model
,
ModelStatus
,
Node
,
Evaluator
from
nni.retiarii.graph
import
Cell
,
Graph
,
Model
,
ModelStatus
,
Node
,
Evaluator
from
nni.retiarii.mutator
import
Mutator
from
nni.retiarii.mutator
import
Mutator
from
nni.retiarii.serializer
import
is_basic_unit
,
is_model_wrapped
from
nni.retiarii.serializer
import
is_basic_unit
,
is_model_wrapped
from
nni.retiarii.utils
import
uid
from
nni.retiarii.utils
import
ModelNamespace
,
uid
from
.api
import
LayerChoice
,
InputChoice
,
ValueChoice
,
ValueChoiceX
,
Placeholder
from
.api
import
LayerChoice
,
InputChoice
,
ValueChoice
,
ValueChoiceX
,
Placeholder
from
.component
import
NasBench101Cell
,
NasBench101Mutator
from
.component
import
NasBench101Cell
,
NasBench101Mutator
...
@@ -285,6 +285,13 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
...
@@ -285,6 +285,13 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
else
:
else
:
model
.
python_init_params
=
{}
model
.
python_init_params
=
{}
# hyper-parameter choice
namespace
:
ModelNamespace
=
pytorch_model
.
_model_namespace
for
param_spec
in
namespace
.
parameter_specs
:
assert
param_spec
.
categorical
and
param_spec
.
type
==
'choice'
node
=
graph
.
add_node
(
f
'param_spec_
{
param_spec
.
name
}
'
,
'ModelParameterChoice'
,
{
'candidates'
:
param_spec
.
values
})
node
.
label
=
param_spec
.
name
for
name
,
module
in
pytorch_model
.
named_modules
():
for
name
,
module
in
pytorch_model
.
named_modules
():
# tricky case: value choice that serves as parameters are stored in traced arguments
# tricky case: value choice that serves as parameters are stored in traced arguments
if
is_basic_unit
(
module
):
if
is_basic_unit
(
module
):
...
...
nni/retiarii/serializer.py
View file @
8d5f643c
...
@@ -120,7 +120,8 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
...
@@ -120,7 +120,8 @@ def model_wrapper(cls: T) -> Union[T, Traceable]:
class
reset_wrapper
(
wrapper
):
class
reset_wrapper
(
wrapper
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
with
ModelNamespace
():
self
.
_model_namespace
=
ModelNamespace
()
with
self
.
_model_namespace
:
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
_copy_class_wrapper_attributes
(
wrapper
,
reset_wrapper
)
_copy_class_wrapper_attributes
(
wrapper
,
reset_wrapper
)
...
...
nni/retiarii/utils.py
View file @
8d5f643c
...
@@ -9,6 +9,8 @@ from contextlib import contextmanager
...
@@ -9,6 +9,8 @@ from contextlib import contextmanager
from
typing
import
Any
,
List
,
Dict
from
typing
import
Any
,
List
,
Dict
from
pathlib
import
Path
from
pathlib
import
Path
from
nni.common.hpo_utils
import
ParameterSpec
__all__
=
[
'NoContextError'
,
'ContextStack'
,
'ModelNamespace'
]
__all__
=
[
'NoContextError'
,
'ContextStack'
,
'ModelNamespace'
]
...
@@ -111,43 +113,78 @@ class ContextStack:
...
@@ -111,43 +113,78 @@ class ContextStack:
class
ModelNamespace
:
class
ModelNamespace
:
"""
"""
To create an individual namespace for models to enable automatic numbering.
To create an individual namespace for models:
1. to enable automatic numbering;
2. to trace general information (like creation of hyper-parameters) of model.
A namespace is bounded to a key. Namespace bounded to different keys are completed isolated.
Namespace can have sub-namespaces (with the same key). The numbering will be chained (e.g., ``model_1_4_2``).
"""
"""
def
__init__
(
self
,
key
:
str
=
_DEFAULT_MODEL_NAMESPACE
):
def
__init__
(
self
,
key
:
str
=
_DEFAULT_MODEL_NAMESPACE
):
# for example, key: "model_wrapper"
# for example, key: "model_wrapper"
self
.
key
=
key
self
.
key
=
key
# the "path" of current name
# By default, it's ``[]``
# If a ``@model_wrapper`` is nested inside a model_wrapper, it will become something like ``[1, 3, 2]``.
# See ``__enter__``.
self
.
name_path
:
List
[
int
]
=
[]
# parameter specs.
# Currently only used trace calls of ModelParameterChoice.
self
.
parameter_specs
:
List
[
ParameterSpec
]
=
[]
def
__enter__
(
self
):
def
__enter__
(
self
):
# For example, currently the top of stack is [1, 2, 2], and [1, 2, 2, 3] is used,
# For example, currently the top of stack is [1, 2, 2], and [1, 2, 2, 3] is used,
# the next thing up is [1, 2, 2, 4].
# the next thing up is [1, 2, 2, 4].
# `reset_uid` to count from zero for "model_wrapper_1_2_2_4"
# `reset_uid` to count from zero for "model_wrapper_1_2_2_4"
try
:
try
:
current_context
=
ContextStack
.
top
(
self
.
key
)
parent_context
:
'ModelNamespace'
=
ModelNamespace
.
current_context
(
self
.
key
)
next_uid
=
uid
(
self
.
_simple_name
(
self
.
key
,
current_context
))
next_uid
=
uid
(
parent_context
.
_simple_name
())
ContextStack
.
push
(
self
.
key
,
current_context
+
[
next_uid
])
self
.
name_path
=
parent_context
.
name_path
+
[
next_uid
]
reset_uid
(
self
.
_simple_name
(
self
.
key
,
current_context
+
[
next_uid
]))
ContextStack
.
push
(
self
.
key
,
self
)
reset_uid
(
self
.
_simple_name
())
except
NoContextError
:
except
NoContextError
:
ContextStack
.
push
(
self
.
key
,
[])
# not found, no existing namespace
reset_uid
(
self
.
_simple_name
(
self
.
key
,
[]))
self
.
name_path
=
[]
ContextStack
.
push
(
self
.
key
,
self
)
reset_uid
(
self
.
_simple_name
())
def
__exit__
(
self
,
*
args
,
**
kwargs
):
def
__exit__
(
self
,
*
args
,
**
kwargs
):
ContextStack
.
pop
(
self
.
key
)
ContextStack
.
pop
(
self
.
key
)
def
_simple_name
(
self
)
->
str
:
return
self
.
key
+
''
.
join
([
'_'
+
str
(
k
)
for
k
in
self
.
name_path
])
def
__repr__
(
self
):
return
f
'ModelNamespace(name=
{
self
.
_simple_name
()
}
, num_specs=
{
len
(
self
.
parameter_specs
)
}
)'
# Access the current context in the model #
@
staticmethod
def
current_context
(
key
:
str
=
_DEFAULT_MODEL_NAMESPACE
)
->
'ModelNamespace'
:
"""Get the current context in key."""
try
:
return
ContextStack
.
top
(
key
)
except
NoContextError
:
raise
NoContextError
(
'ModelNamespace context is missing. You might have forgotten to use `@model_wrapper`.'
)
@
staticmethod
@
staticmethod
def
next_label
(
key
:
str
=
_DEFAULT_MODEL_NAMESPACE
)
->
str
:
def
next_label
(
key
:
str
=
_DEFAULT_MODEL_NAMESPACE
)
->
str
:
"""Get the next label for API calls, with automatic numbering."""
try
:
try
:
current_context
=
ContextStack
.
top
(
key
)
current_context
=
ContextStack
.
top
(
key
)
except
NoContextError
:
except
NoContextError
:
# fallback to use "default" namespace
# fallback to use "default" namespace
return
ModelNamespace
.
_simple_name
(
'default'
,
[
uid
()])
# it won't be registered
warnings
.
warn
(
'ModelNamespace is missing. You might have forgotten to use `@model_wrapper`. '
next_uid
=
uid
(
ModelNamespace
.
_simple_name
(
key
,
current_context
)
)
'Some features might not work. This will be an error in future releases.'
,
RuntimeWarning
)
return
ModelNamespace
.
_simple_name
(
key
,
current_context
+
[
next_uid
]
)
current_context
=
ModelNamespace
(
'default'
)
@
staticmethod
next_uid
=
uid
(
current_context
.
_simple_name
())
def
_simple_name
(
key
:
str
,
lst
:
List
[
Any
])
->
str
:
return
current_context
.
_simple_name
()
+
'_'
+
str
(
next_uid
)
return
key
+
''
.
join
([
'_'
+
str
(
k
)
for
k
in
lst
])
def
get_current_context
(
key
:
str
)
->
Any
:
def
get_current_context
(
key
:
str
)
->
Any
:
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
8d5f643c
...
@@ -17,7 +17,7 @@ from nni.retiarii.graph import Model
...
@@ -17,7 +17,7 @@ from nni.retiarii.graph import Model
from
nni.retiarii.nn.pytorch.api
import
ValueChoice
from
nni.retiarii.nn.pytorch.api
import
ValueChoice
from
nni.retiarii.nn.pytorch.mutator
import
process_evaluator_mutations
,
process_inline_mutation
,
extract_mutation_from_pt_module
from
nni.retiarii.nn.pytorch.mutator
import
process_evaluator_mutations
,
process_inline_mutation
,
extract_mutation_from_pt_module
from
nni.retiarii.serializer
import
model_wrapper
from
nni.retiarii.serializer
import
model_wrapper
from
nni.retiarii.utils
import
ContextStack
,
original_state_dict_hooks
from
nni.retiarii.utils
import
ContextStack
,
NoContextError
,
original_state_dict_hooks
class
EnumerateSampler
(
Sampler
):
class
EnumerateSampler
(
Sampler
):
...
@@ -849,6 +849,65 @@ class Python(GraphIR):
...
@@ -849,6 +849,65 @@ class Python(GraphIR):
@
unittest
.
skip
@
unittest
.
skip
def
test_valuechoice_getitem_functional_expression
(
self
):
...
def
test_valuechoice_getitem_functional_expression
(
self
):
...
def
test_hyperparameter_choice
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
aux
=
nn
.
ModelParameterChoice
([
False
,
True
])
def
forward
(
self
,
x
):
return
x
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
1
)
sampler
=
EnumerateSampler
()
model1
=
_apply_all_mutators
(
model
,
mutators
,
sampler
)
model2
=
_apply_all_mutators
(
model
,
mutators
,
sampler
)
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
).
aux
,
False
)
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model2
).
aux
,
True
)
def
test_hyperparameter_choice_parameter
(
self
):
class
Inner
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
aux
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
1
,
nn
.
ModelParameterChoice
([
64
,
128
,
256
],
label
=
'a'
),
3
,
3
)
)
def
forward
(
self
):
return
self
.
aux
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
choice
=
nn
.
ModelParameterChoice
([
64
,
128
,
256
],
label
=
'a'
)
self
.
inner
=
Inner
()
def
forward
(
self
):
param
=
self
.
inner
()
assert
param
.
size
(
1
)
==
self
.
choice
return
param
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
1
)
sampler
=
RandomSampler
()
result_pool
=
set
()
for
_
in
range
(
20
):
model
=
_apply_all_mutators
(
model
,
mutators
,
sampler
)
result
=
self
.
_get_converted_pytorch_model
(
model
)()
result_pool
.
add
(
result
.
size
(
1
))
self
.
assertSetEqual
(
result_pool
,
{
64
,
128
,
256
})
def
test_hyperparameter_choice_no_model_wrapper
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
choice
=
nn
.
ModelParameterChoice
([
64
,
128
,
256
],
label
=
'a'
)
with
self
.
assertRaises
(
NoContextError
):
model
=
Net
()
def
test_cell_loose_end
(
self
):
def
test_cell_loose_end
(
self
):
@
model_wrapper
@
model_wrapper
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment